diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-24 12:07:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-24 12:07:31 +0100 |
commit | ca318a6ec7ab07a21b4b90727cb42a7242271b4c (patch) | |
tree | 5144a2503689c06a695070b198b9894d5a22ccfb | |
parent | dd64465899f4b58628642b406c465d35ddfe8f79 (diff) | |
download | candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.tar.gz candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.tar.bz2 candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.zip |
Add to the cuda example a reproduction of the issue. (#579)
* Add to the cuda example a reproduction of the issue.
* Tweak.
* Add a test using non-square matrixes.
* Fix the conv2d kernel.
* Display the error.
* And tweak the comment.
-rw-r--r-- | candle-core/examples/cuda_basics.rs | 13 | ||||
-rw-r--r-- | candle-core/tests/conv_tests.rs | 36 | ||||
-rw-r--r-- | candle-kernels/src/conv.cu | 21 |
3 files changed, 58 insertions, 12 deletions
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 6a3aaacc..cbdafd64 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -9,8 +9,17 @@ use candle_core::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let t = Tensor::rand(-1f32, 1f32, 96, &device)?; - println!("{t}"); + let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?; + let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?; + let out_t = in_t.conv2d(&k_t, 0, 1, 1)?; + println!("{out_t}"); + let in_t = in_t.to_device(&Device::Cpu)?; + let k_t = k_t.to_device(&Device::Cpu)?; + let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?; + let diff = (out_t.to_device(&Device::Cpu)? - out_t2)? + .sqr()? + .sum_all()?; + println!("{diff}"); let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?; let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?; diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 1177c1cf..7a4c2956 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -183,8 +183,44 @@ fn conv2d_smaller(dev: &Device) -> Result<()> { Ok(()) } +/* This test is based on the following script. +import torch +torch.manual_seed(4242) + +t = torch.randn((1, 2, 4, 2)) +w = torch.randn((1, 2, 1, 1)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv2d(t, w) +print(res.flatten()) +*/ +fn conv2d_non_square(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, + ], + dev, + )?; + let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?; + let t = t.reshape((1, 2, 4, 2))?; + let w = w.reshape((1, 2, 1, 1))?; + let res = t.conv2d(&w, 0, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 4, 2]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467] + ); + Ok(()) +} + test_device!(conv1d, conv1d_cpu, conv1d_gpu); test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu); test_device!(conv2d, conv2d_cpu, conv2d_gpu); +test_device!( + conv2d_non_square, + conv2d_non_square_cpu, + conv2d_non_square_gpu +); test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu); test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu); diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index afda7d1d..19d94385 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -64,18 +64,18 @@ __device__ void conv2d( T *dst ) { const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; - // src: (b_size, c_in, w_in, h_in) - // k: (c_out, c_in, w_k, h_k) + // src: (b_size, c_in, h_in, w_in) + // k: (c_out, c_in, h_k, w_k) const size_t *src_dims = info; const size_t *src_s = info + 4; const size_t *k_dims = info + 8; const size_t *k_s = info + 12; - const size_t w_k = k_dims[2]; - const size_t h_k = k_dims[3]; + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; const size_t c_out = k_dims[0]; const size_t c_in = src_dims[1]; - const size_t w_in = src_dims[2]; - const size_t h_in = src_dims[3]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; if (dst_i >= src_dims[0] * c_out * w_out * h_out) { return; } @@ -83,8 +83,9 @@ __device__ void conv2d( // TODO const size_t b_idx = dst_i / (w_out * h_out * c_out); const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out; - const size_t dst_w = (dst_i / h_out) % w_out; - const size_t dst_h = dst_i % h_out; + // NCHW layout. + const size_t dst_h = (dst_i / w_out) % h_out; + const size_t dst_w = dst_i % w_out; const size_t src_idx0 = b_idx * src_s[0]; A d = 0; @@ -101,8 +102,8 @@ __device__ void conv2d( } src_h -= padding; for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { - const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; - const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + w_offset * k_s[2] + h_offset * k_s[3]; + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_h * src_s[2] + src_w * src_s[3]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + h_offset * k_s[2] + w_offset * k_s[3]; d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]); } } |