summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs62
-rw-r--r--candle-core/tests/conv_tests.rs20
-rw-r--r--candle-kernels/src/conv.cu79
3 files changed, 143 insertions, 18 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 25d73f9b..b7756fa6 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1149,6 +1149,55 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
+impl<'a> Map2 for ConvTranspose1D<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ inp_l: &Layout,
+ k: &CudaSlice<T>,
+ k_l: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ // Kernel shape: (c_in_k, c_out, l_k)
+ // Input shape: (b_size, c_in, l_in)
+ let p = &self.0;
+ let l_out = p.l_out();
+ let dst_el = p.c_out * l_out * p.b_size;
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let k = &k.slice(k_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
+ let ds = if dims.len() == 3 {
+ [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+ } else {
+ crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
+ };
+ let ds = dev.htod_copy(ds).w()?;
+ let params = (
+ el,
+ l_out,
+ p.stride,
+ p.padding,
+ p.output_padding,
+ p.dilation,
+ &ds,
+ inp,
+ k,
+ &out,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1810,12 +1859,15 @@ impl BackendStorage for CudaStorage {
fn conv_transpose1d(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &crate::conv::ParamsConvTranspose1D,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
- todo!()
+ let device = self.device().clone();
+ let slice =
+ ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ Ok(Self { slice, device })
}
#[cfg(not(feature = "cudnn"))]
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 39c6cec0..5bbd903d 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -50,17 +50,15 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
- if dev.is_cpu() {
- let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
- assert_eq!(res.dims(), [1, 2, 7]);
- assert_eq!(
- test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [
- 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
- 4.7076, -5.9745, -0.8276, 1.621
- ],
- );
- }
+ let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+ assert_eq!(res.dims(), [1, 2, 7]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [
+ 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
+ 4.7076, -5.9745, -0.8276, 1.621
+ ],
+ );
Ok(())
}
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 9c8ce00f..fed920f1 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -71,7 +71,6 @@ __device__ void im2col1d(
}
const size_t *src_dims = info;
const size_t *src_s = info + 3;
- const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
@@ -120,7 +119,6 @@ __device__ void im2col(
}
const size_t *src_dims = info;
const size_t *src_s = info + 4;
- const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
@@ -225,6 +223,60 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
+// Naive implementation of conv_transpose1d.
+template <typename T, typename A>
+__device__ void conv_transpose1d(
+ const size_t src_numel,
+ const size_t l_out,
+ const size_t stride,
+ const size_t padding,
+ const size_t out_padding,
+ const size_t dilation,
+ const size_t *info,
+ const T *src,
+ const T *kernel,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, l_in)
+ // k: (c_in, c_out, l_k)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 3;
+ const size_t *k_dims = info + 6;
+ const size_t *k_s = info + 9;
+ const size_t l_k = k_dims[2];
+ const size_t c_out = k_dims[1];
+ const size_t c_in = src_dims[1];
+ const size_t l_in = src_dims[2];
+ if (dst_i >= src_dims[0] * c_out * l_out) {
+ return;
+ }
+
+ // TODO
+ const size_t b_idx = dst_i / (l_out * c_out);
+ const size_t dst_c_idx = (dst_i / l_out) % c_out;
+ // NCL layout.
+ const size_t out_x = dst_i % l_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ A d = 0;
+ for (int k_x = 0; k_x < (int)l_k; ++k_x) {
+ // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
+ if (inp_x_stride < 0 || inp_x_stride % stride) {
+ continue;
+ }
+ int inp_x = inp_x_stride / stride;
+ if (inp_x >= l_in) continue;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
+ const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
+ d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
+ }
+ }
+ dst[dst_i] = static_cast<T>(d);
+}
+
// Naive implementation of conv_transpose2d.
template <typename T, typename A>
__device__ void conv_transpose2d(
@@ -507,6 +559,22 @@ extern "C" __global__ void FN_NAME( \
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \
+#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t l_out, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t out_padding, \
+ const size_t dilation, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ const TYPENAME *kernel, \
+ TYPENAME *dst \
+) { \
+ conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
+} \
+
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@@ -568,6 +636,7 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
+CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
@@ -579,6 +648,7 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
+CONVT1D_OP(__half, float, conv_transpose1d_f16)
CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
@@ -597,6 +667,11 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
+CONVT1D_OP(float, float, conv_transpose1d_f32)
+CONVT1D_OP(double, double, conv_transpose1d_f64)
+CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
+CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
+
CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(double, double, conv_transpose2d_f64)
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)