diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-25 11:03:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-25 11:03:23 +0200 |
commit | 0814dfd148474f436bb43314fc41639a7b429aab (patch) | |
tree | 7301c78b9e635bc09a268f3293a1f8f8b2997e6f /candle-metal-kernels | |
parent | 3ceca9901a5ebc4ded3ac2cd793d0125f7a12562 (diff) | |
download | candle-0814dfd148474f436bb43314fc41639a7b429aab.tar.gz candle-0814dfd148474f436bb43314fc41639a7b429aab.tar.bz2 candle-0814dfd148474f436bb43314fc41639a7b429aab.zip |
Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.
* Enable the col2im variant.
* Bugfix.
* Revert the quantized tweak.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/conv.metal | 65 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 33 |
2 files changed, 97 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index 8fdd0e5f..5348a0f0 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -69,6 +69,50 @@ METAL_FUNC void im2col( } template <typename T> +METAL_FUNC void col2im1d( + constant size_t &dst_el, + constant size_t &l_out, + constant size_t &l_in, + constant size_t &c_out, + constant size_t &k_size, + constant size_t &stride, + device const T *src, + device T *dst, + uint dst_i [[ thread_position_in_grid ]] +) { + // src: (b_size, l_in, c_out, l_k) + // dst: (b_size, c_out, l_out) + if (dst_i >= dst_el) { + return; + } + + const size_t dst_s0 = c_out * l_out; + const size_t dst_s1 = l_out; + const size_t src_s0 = c_out * k_size * l_in; + const size_t src_s1 = c_out * k_size; + const size_t src_s2 = k_size; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t c_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= c_idx * dst_s1; + const int l_out_idx = tmp_dst_i; + + dst[dst_i] = static_cast<T>(0); + + int l_in_idx = l_out_idx / stride; + int k0 = l_out_idx - l_in_idx * stride; + // l_out_idx = l_in_idx * stride + k0 + for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) { + if (l_in_idx < l_in) { + const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0; + dst[dst_i] += src[src_i]; + } + } +} + +template <typename T> METAL_FUNC void im2col1d( constant size_t &dst_numel, constant size_t &l_out, @@ -190,6 +234,21 @@ kernel void FN_NAME( \ ) { \ im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \ } \ + +#define COL2IM1D_OP(T, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &dst_el, \ + constant size_t &l_out, \ + constant size_t &l_in, \ + constant size_t &c_out, \ + constant size_t &k_size, \ + constant size_t &stride, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \ +} \ #define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ kernel void FN_NAME( \ @@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32) IM2COL_OP(bfloat, im2col_bf16) #endif +COL2IM1D_OP(float, col2im1d_f32) +COL2IM1D_OP(uint8_t, col2im1d_u8) +COL2IM1D_OP(uint32_t, col2im1d_u32) + IM2COL1D_OP(float, im2col1d_f32) IM2COL1D_OP(uint8_t, im2col1d_u8) IM2COL1D_OP(uint32_t, im2col1d_u32) @@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32) CONVT2D_OP(half, float, conv_transpose2d_f16) #if defined(__HAVE_BFLOAT__) CONVT1D_OP(bfloat, float, conv_transpose2d_bf16) -#endif
\ No newline at end of file +#endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index aa157a2f..1815dd32 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1652,6 +1652,39 @@ pub fn call_im2col1d_strided( } #[allow(clippy::too_many_arguments)] +pub fn call_col2im1d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + k_size: usize, + stride: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let l_in = shape[1]; + let c_out = shape[2]; + let l_out = (l_in - 1) * stride + k_size; + let dst_el = shape[0] * c_out * l_out; + + let encoder = command_buffer.new_compute_command_encoder(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (dst_el, l_out, l_in, c_out, k_size, stride, &input, output) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] pub fn call_im2col_strided( device: &Device, command_buffer: &CommandBufferRef, |