-rw-r--r--  candle-core/src/cuda_backend.rs   66
-rw-r--r--  candle-core/tests/conv_tests.rs  128
-rw-r--r--  candle-kernels/src/conv.cu        88
3 files changed, 206 insertions(+), 76 deletions(-)
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 75eaf70a..ed696368 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -977,8 +977,8 @@ impl<'a> Map2 for Conv2D<'a> {
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
- // Kernel shape: (c_out, c_in_k, w_k, h_k)
- // Input shape: (b_size, c_in, w_in, c_in)
+ // Kernel shape: (c_out, c_in_k, h_k, w_k)
+ // Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
@@ -1005,6 +1005,55 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
+impl<'a> Map2 for ConvTranspose2D<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ inp_l: &Layout,
+ k: &CudaSlice<T>,
+ k_l: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ // Kernel shape: (c_in_k, c_out, h_k, w_k)
+ // Input shape: (b_size, c_in, h_in, w_in)
+ let p = &self.0;
+ let (out_w, out_h) = (p.out_w(), p.out_h());
+ let dst_el = p.c_out * out_w * out_h * p.b_size;
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let k = &k.slice(k_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
+ let ds = if dims.len() == 4 {
+ [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+ } else {
+ crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
+ };
+ let ds = dev.htod_copy(ds).w()?;
+ let params = (
+ el,
+ out_w,
+ out_h,
+ p.stride,
+ p.padding,
+ p.output_padding,
+ &ds,
+ inp,
+ k,
+ &out,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
enum PoolOp {
Max,
Avg,
@@ -1649,12 +1698,15 @@ impl BackendStorage for CudaStorage {
fn conv_transpose2d(
&self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &crate::conv::ParamsConvTranspose2D,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
- todo!()
+ let device = self.device().clone();
+ let slice =
+ ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ Ok(Self { slice, device })
}
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
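Note: the output dimensions used above (`p.out_w()`, `p.out_h()`) are assumed to follow the usual transposed-convolution size formula. A minimal sketch of that relation in Rust, with hypothetical names rather than the candle API:

    // out = (in - 1) * stride - 2 * padding + kernel + output_padding
    fn conv_transpose2d_out_dim(
        in_dim: usize,
        k_dim: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
    ) -> usize {
        (in_dim - 1) * stride + k_dim + output_padding - 2 * padding
    }
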
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 4fe76378..1c378e5e 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -122,33 +122,31 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
- if dev.is_cpu() {
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
- assert_eq!(res.dims(), [1, 2, 7, 7]);
- assert_eq!(
- test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ assert_eq!(res.dims(), [1, 2, 7, 7]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ [
[
- [
- [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
- [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
- [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
- [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
- [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
- [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
- [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
- ],
- [
- [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
- [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
- [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
- [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
- [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
- [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
- [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
- ]
+ [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+ [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+ [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+ [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+ [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+ [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+ [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+ ],
+ [
+ [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+ [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+ [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+ [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+ [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+ [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+ [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
- );
- }
+ ]
+ );
Ok(())
}
@@ -202,26 +200,23 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
- // TODO: enable the test for cuda once we have the proper implementation in place.
- if dev.is_cpu() {
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
- assert_eq!(res.dims(), [1, 1, 3, 3]);
- assert_eq!(
- test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
- );
- let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
- assert_eq!(res.dims(), [2, 2, 3, 3]);
- assert_eq!(
- test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [
- -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
- 0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
- 0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
- -0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
- ]
- );
- }
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ assert_eq!(res.dims(), [1, 1, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
+ );
+ let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+ assert_eq!(res.dims(), [2, 2, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [
+ -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
+ -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
+ -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
+ 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
+ ]
+ );
Ok(())
}
@@ -275,10 +270,8 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
Ok(())
}
-#[test]
-fn conv2d_grad() -> Result<()> {
+fn conv2d_grad(dev: &Device) -> Result<()> {
use candle_core::Var;
- let dev = &Device::Cpu;
let t = Var::from_slice(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@@ -318,32 +311,28 @@ fn conv2d_grad() -> Result<()> {
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
- test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
+ test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
[
- 9.2868, -2.8352, -5.7117, 3.3817, -7.7094, -19.1549, 7.016, 29.1037, 9.3411, 34.7339,
- -22.8726, 24.3502, -39.88, -14.007, 21.076, 9.9419, 13.6333, -34.6796, 11.2073,
- -6.2617, 7.7209, -6.3224, -16.6373, -1.0837, -20.2215, 21.7302, -0.3744, -4.0573,
- 5.8163, -3.6529, -30.7319, 14.5468, 87.699, 31.6035, 4.5304, -89.785, -75.3709,
- -57.4327, -7.5602, 92.9585, 18.791, -4.6311, -159.7521, -42.4656, -47.2644, 52.8768,
- 37.3172, 48.9978, 12.8192, 2.014, -8.9826, 20.1759, 16.621, 12.0599, 15.3849, 19.9979,
- 2.5725, -15.2197, 72.6244, -10.7496, 2.2541, -31.2003, 3.753, -0.2049, 9.7574, -0.6824,
- 5.2107, -40.4361, -22.5891, -61.6085, 17.2837, 20.4149, 37.5454, 5.2262, 6.8126,
- 23.5361, 23.6173, -9.9866, -9.1324, 4.8664, -35.0617, -26.1023, 63.4757, 25.8144,
- -39.2069, -70.6834, -46.9565, 2.3252, 41.8093, 82.4205, -28.626, -11.7812, -35.3284,
- -10.2771, -28.5694, -9.1258, 7.213, -9.0459, -9.6222, -11.2544
+ 9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
+ -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
+ -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
+ -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
+ 49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
+ -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
+ 20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
+ 25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
+ -28.57, -9.13, 7.21, -9.05, -9.62, -11.25
]
);
assert_eq!(
- test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?,
+ test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
[
- -28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
- -13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
- 18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
- 9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
- 20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
- 22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
- 44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
- -33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
+ -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
+ 28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
+ 22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
+ 58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
+ 47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
+ -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
]
);
Ok(())
@@ -359,3 +348,4 @@ test_device!(
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
+test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
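
Note: the tests call `w.transpose(0, 1)?` before `conv_transpose2d` because, per the layout comments in cuda_backend.rs above, conv2d weights are laid out as (c_out, c_in, h_k, w_k) while conv_transpose2d expects (c_in, c_out, h_k, w_k). A minimal usage sketch, assuming the (padding, output_padding, stride) argument order suggested by the params struct:

    // w: (c_out, c_in, h_k, w_k)  ->  wt: (c_in, c_out, h_k, w_k)
    let wt = w.transpose(0, 1)?;
    // padding = 0, output_padding = 0, stride = 1 (argument order assumed).
    let res = t.conv_transpose2d(&wt, 0, 0, 1)?;
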
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 19d94385..5ccce317 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -111,6 +111,71 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
+// Naive implementation of conv_transpose2d.
+template <typename T, typename A>
+__device__ void conv_transpose2d(
+ const size_t src_numel,
+ const size_t w_out,
+ const size_t h_out,
+ const size_t stride,
+ const size_t padding,
+ const size_t out_padding,
+ const size_t *info,
+ const T *src,
+ const T *kernel,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, h_in, w_in)
+ // k: (c_in, c_out, h_k, w_k)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+ const size_t *k_dims = info + 8;
+ const size_t *k_s = info + 12;
+ const size_t h_k = k_dims[2];
+ const size_t w_k = k_dims[3];
+ const size_t c_out = k_dims[1];
+ const size_t c_in = src_dims[1];
+ const size_t h_in = src_dims[2];
+ const size_t w_in = src_dims[3];
+ if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
+ return;
+ }
+
+ // Compute the output coordinates (batch, channel, y, x) handled by this thread.
+ const size_t b_idx = dst_i / (w_out * h_out * c_out);
+ const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
+ // NCHW layout.
+ const size_t out_y = (dst_i / w_out) % h_out;
+ const size_t out_x = dst_i % w_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ A d = 0;
+ for (int k_x = 0; k_x < (int)w_k; ++k_x) {
+ // let out_x = inp_x * p.stride + k_x - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x;
+ if (inp_x_stride < 0 || inp_x_stride % stride) {
+ continue;
+ }
+ int inp_x = inp_x_stride / stride;
+ if (inp_x >= w_in) continue;
+ for (int k_y = 0; k_y < (int)h_k; ++k_y) {
+ int inp_y_stride = (int)(out_y + padding) - k_y;
+ if (inp_y_stride < 0 || inp_y_stride % stride) {
+ continue;
+ }
+ int inp_y = inp_y_stride / stride;
+ if (inp_y >= h_in) continue;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
+ const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
+ d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
+ }
+ }
+ }
+ dst[dst_i] = static_cast<T>(d);
+}
+
template <typename T, typename A>
__device__ void avg_pool2d(
const size_t src_numel,
@@ -293,6 +358,22 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
} \
+#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t w_out, \
+ const size_t h_out, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t out_padding, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ const TYPENAME *kernel, \
+ TYPENAME *dst \
+) { \
+ conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+} \
+
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@@ -337,6 +418,7 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
+CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
@@ -345,6 +427,7 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
+CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
@@ -360,6 +443,11 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
+CONVT2D_OP(float, float, conv_transpose2d_f32)
+CONVT2D_OP(double, double, conv_transpose2d_f64)
+CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
+CONVT2D_OP(uint32_t, uint32_t, conv_transpose2d_u32)
+
AVG_POOL2D_OP(float, float, avg_pool2d_f32)
AVG_POOL2D_OP(double, double, avg_pool2d_f64)
AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
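
Note: the conv_transpose2d kernel above gathers rather than scatters: for each output element it inverts the forward relation out = in * stride + k - padding, keeping only kernel offsets for which (out + padding - k) is non-negative, divisible by the stride, and within the input bounds. A scalar Rust mirror of that index test, with illustrative names only:

    // Input x index contributing to `out_x` through kernel column `k_x`,
    // or None if this (out_x, k_x) pair touches no input element.
    fn inp_x_for(out_x: usize, k_x: usize, stride: usize, padding: usize, w_in: usize) -> Option<usize> {
        // Invert out_x = inp_x * stride + k_x - padding.
        let inp_x_stride = (out_x + padding).checked_sub(k_x)?;
        if inp_x_stride % stride != 0 {
            return None;
        }
        let inp_x = inp_x_stride / stride;
        (inp_x < w_in).then_some(inp_x)
    }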