diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-24 12:07:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-24 12:07:31 +0100 |
commit | ca318a6ec7ab07a21b4b90727cb42a7242271b4c (patch) | |
tree | 5144a2503689c06a695070b198b9894d5a22ccfb /candle-kernels/src | |
parent | dd64465899f4b58628642b406c465d35ddfe8f79 (diff) | |
download | candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.tar.gz candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.tar.bz2 candle-ca318a6ec7ab07a21b4b90727cb42a7242271b4c.zip |
Add to the cuda example a reproduction of the issue. (#579)
* Add to the cuda example a reproduction of the issue.
* Tweak.
* Add a test using non-square matrixes.
* Fix the conv2d kernel.
* Display the error.
* And tweak the comment.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/conv.cu | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index afda7d1d..19d94385 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -64,18 +64,18 @@ __device__ void conv2d( T *dst ) { const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; - // src: (b_size, c_in, w_in, h_in) - // k: (c_out, c_in, w_k, h_k) + // src: (b_size, c_in, h_in, w_in) + // k: (c_out, c_in, h_k, w_k) const size_t *src_dims = info; const size_t *src_s = info + 4; const size_t *k_dims = info + 8; const size_t *k_s = info + 12; - const size_t w_k = k_dims[2]; - const size_t h_k = k_dims[3]; + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; const size_t c_out = k_dims[0]; const size_t c_in = src_dims[1]; - const size_t w_in = src_dims[2]; - const size_t h_in = src_dims[3]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; if (dst_i >= src_dims[0] * c_out * w_out * h_out) { return; } @@ -83,8 +83,9 @@ __device__ void conv2d( // TODO const size_t b_idx = dst_i / (w_out * h_out * c_out); const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out; - const size_t dst_w = (dst_i / h_out) % w_out; - const size_t dst_h = dst_i % h_out; + // NCHW layout. + const size_t dst_h = (dst_i / w_out) % h_out; + const size_t dst_w = dst_i % w_out; const size_t src_idx0 = b_idx * src_s[0]; A d = 0; @@ -101,8 +102,8 @@ __device__ void conv2d( } src_h -= padding; for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { - const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; - const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + w_offset * k_s[2] + h_offset * k_s[3]; + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_h * src_s[2] + src_w * src_s[3]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + h_offset * k_s[2] + w_offset * k_s[3]; d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]); } } |