author     Thomas Santerre <thomas@santerre.xyz>  2024-03-19 03:46:58 -0400
committer  GitHub <noreply@github.com>            2024-03-19 08:46:58 +0100
commit     2a8679509eb55232b37378442c4366343f6dcb11
tree       7fe5881c3441f94d4534e70c1a5ec7c6ead123e2 /candle-metal-kernels
parent     143c481c20abc3420e848eab075d1547a96cc447
Add support for conv_transpose1d for metal backend (#1874)
* first attempt
* progress
* integrate into metal backend
* finish and get test passing
* add other dtype support
* update transpose1d dtypes supported
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/conv.metal   78
-rw-r--r--  candle-metal-kernels/src/lib.rs       53
-rw-r--r--  candle-metal-kernels/src/tests.rs    216
3 files changed, 347 insertions, 0 deletions
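Before the diffs themselves, it may help to recall the output-length relation for a transposed 1d convolution; the patch relies on it both in the launcher (which receives l_out) and in the tests (run_conv_transpose1d computes it). A minimal Rust sketch, with an illustrative helper name that is not part of the patch:

// Illustrative helper: the standard transposed-conv output length,
// matching the l_out computation in run_conv_transpose1d below.
fn conv_transpose1d_out_len(
    l_in: usize,
    k_size: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    dilation: usize,
) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
}

fn main() {
    // The tests below use l_in = 4, k_size = 4, stride 1, padding 0,
    // out_padding 0, dilation 1, giving l_out = 7.
    assert_eq!(conv_transpose1d_out_len(4, 4, 1, 0, 0, 1), 7);
}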
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index 7f7a75cf..a258ae58 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -335,6 +335,76 @@ kernel void FN_NAME( \
   max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
 } \
+
+// Naive implementation of conv_transpose1d.
+template <typename T, typename A>
+METAL_FUNC void conv_transpose1d(
+    constant size_t &l_out,
+    constant size_t &stride,
+    constant size_t &padding,
+    constant size_t &out_padding,
+    constant size_t &dilation,
+    constant size_t *src_dims,
+    constant size_t *src_strides,
+    constant size_t *k_dims,
+    constant size_t *k_strides,
+    device const T *src,
+    device const T *k,
+    device T *dst,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    // src: (b_size, c_in, l_in)
+    // kernel: (c_in, c_out, l_k)
+    const size_t l_k = k_dims[2];
+    const size_t c_out = k_dims[1];
+    const size_t c_in = src_dims[1];
+    const size_t l_in = src_dims[2];
+    if (tid >= src_dims[0] * c_out * l_out) {
+        return;
+    }
+
+    const size_t b_idx = tid / (l_out * c_out);
+    const size_t dst_c_idx = (tid / l_out) % c_out;
+    const size_t out_x = tid % l_out;
+
+    const size_t src_idx0 = b_idx * src_strides[0];
+    A d = 0;
+    for (int k_x = 0; k_x < (int)l_k; ++k_x) {
+        // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
+        int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
+        if (inp_x_stride < 0 || inp_x_stride % stride) {
+            continue;
+        }
+        int inp_x = inp_x_stride / stride;
+        if (inp_x >= l_in) continue;
+        for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+            const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
+            const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
+            d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
+        }
+    }
+    dst[tid] = static_cast<T>(d);
+}
+
+#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &l_out, \
+    constant size_t &stride, \
+    constant size_t &padding, \
+    constant size_t &out_padding, \
+    constant size_t &dilation, \
+    constant size_t *src_dims, \
+    constant size_t *src_strides, \
+    constant size_t *k_dims, \
+    constant size_t *k_strides, \
+    device const TYPENAME *src, \
+    device const TYPENAME *k, \
+    device TYPENAME *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+  conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
+} \
+
 IM2COL_OP(float, im2col_f32)
 IM2COL_OP(uint8_t, im2col_u8)
 IM2COL_OP(uint32_t, im2col_u32)
@@ -361,4 +431,12 @@ AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
 AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
 #if defined(__HAVE_BFLOAT__)
 AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
+#endif
+
+CONVT1D_OP(float, float, conv_transpose1d_f32)
+CONVT1D_OP(half, float, conv_transpose1d_f16)
+CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
+CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
+#if defined(__HAVE_BFLOAT__)
+CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
 #endif
\ No newline at end of file
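The kernel above is output-centric: each thread owns one output element and gathers every (input position, kernel tap) pair that maps onto it, inverting out_x = inp_x * stride + k_x * dilation - padding and skipping taps where the inversion is negative or not a multiple of the stride. A plain-Rust mirror of the same loop, useful for checking the indexing on the CPU; the function name and the contiguous layout are illustrative, not part of the candle API:

// Illustrative CPU reference for the Metal kernel above (contiguous layout).
fn conv_transpose1d_ref(
    src: &[f32],     // (b_size, c_in, l_in)
    k: &[f32],       // (c_in, c_out, l_k)
    b_size: usize,
    c_in: usize,
    l_in: usize,
    c_out: usize,
    l_k: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    l_out: usize,
) -> Vec<f32> {
    let mut dst = vec![0.0f32; b_size * c_out * l_out];
    for tid in 0..dst.len() {
        // Same index decomposition as the kernel: one thread per output element.
        let b_idx = tid / (l_out * c_out);
        let dst_c_idx = (tid / l_out) % c_out;
        let out_x = tid % l_out;
        let mut d = 0.0f32;
        for k_x in 0..l_k {
            // Invert out_x = inp_x * stride + k_x * dilation - padding.
            let inp_x_stride = (out_x + padding) as isize - (k_x * dilation) as isize;
            if inp_x_stride < 0 || inp_x_stride as usize % stride != 0 {
                continue;
            }
            let inp_x = inp_x_stride as usize / stride;
            if inp_x >= l_in {
                continue;
            }
            for src_c_idx in 0..c_in {
                let src_idx = b_idx * c_in * l_in + src_c_idx * l_in + inp_x;
                let k_idx = src_c_idx * c_out * l_k + dst_c_idx * l_k + k_x;
                d += src[src_idx] * k[k_idx];
            }
        }
        dst[tid] = d;
    }
    dst
}

fn main() {
    // Reproduces the f32 test case from tests.rs below.
    let src = [1.0f32, 2.0, 3.0, 4.0]; // (1, 1, 4)
    let k = [1.0f32, 2.0, 3.0, 4.0];   // (1, 1, 4)
    let dst = conv_transpose1d_ref(&src, &k, 1, 1, 4, 1, 4, 1, 0, 1, 7);
    assert_eq!(dst, vec![1., 4., 10., 20., 25., 24., 16.]);
}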
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 1161501f..f12463a4 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1859,5 +1859,58 @@ pub fn call_pool2d(
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
+pub fn call_conv_transpose1d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    dilation: usize,
+    stride: usize,
+    padding: usize,
+    out_padding: usize,
+    c_out: usize,
+    l_out: usize,
+    b_size: usize,
+    src_shape: &[usize],
+    src_strides: &[usize],
+    kernel_shape: &[usize],
+    kernel_strides: &[usize],
+    input: &Buffer,
+    input_offset: usize,
+    kernel: &Buffer,
+    kernel_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let dst_el = c_out * l_out * b_size;
+    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            l_out,
+            stride,
+            padding,
+            out_padding,
+            dilation,
+            src_shape,
+            src_strides,
+            kernel_shape,
+            kernel_strides,
+            (input, input_offset),
+            (kernel, kernel_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
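The launcher dispatches one thread per output element (dst_el = c_out * l_out * b_size) and leans on the kernel's own bounds check to handle the tail of the last threadgroup. The tests that follow all use stride 1, padding 0, and dilation 1, where a transposed convolution reduces to a "full" convolution. A minimal sketch (not part of the patch) that reproduces the expected vector the tests assert against:

// Full 1d convolution: every input sample scatters across every kernel tap.
fn full_conv(src: &[f32], k: &[f32]) -> Vec<f32> {
    let mut dst = vec![0.0f32; src.len() + k.len() - 1];
    for (i, &s) in src.iter().enumerate() {
        for (j, &kv) in k.iter().enumerate() {
            dst[i + j] += s * kv;
        }
    }
    dst
}

fn main() {
    // Matches the expected output in each conv_transpose1d_* test below.
    let out = full_conv(&[1., 2., 3., 4.], &[1., 2., 3., 4.]);
    assert_eq!(out, vec![1., 4., 10., 20., 25., 24., 16.]);
}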
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 19e160dd..5045a4a3 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1717,3 +1717,219 @@ fn avg_pool2d_u32() {
     let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
     assert_eq!(results, expected);
 }
+
+fn run_conv_transpose1d<T: Clone>(
+    input: &[T],
+    input_shape: &[usize],
+    input_stride: &[usize],
+    kernel: &[T],
+    kernel_shape: &[usize],
+    kernel_stride: &[usize],
+    dilation: usize,
+    stride: usize,
+    padding: usize,
+    out_padding: usize,
+    name: &'static str,
+) -> Vec<T> {
+    let device = device();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+
+    let c_out = kernel_shape[1];
+    let k_size = kernel_shape[2];
+    let b_size = input_shape[0];
+    let l_in = input_shape[2];
+    let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
+    let dst_el = c_out * l_out * b_size;
+
+    let input = new_buffer(&device, input);
+    let kernel = new_buffer(&device, kernel);
+    let output = new_buffer(&device, &vec![0.0f32; dst_el]);
+    let kernels = Kernels::new();
+
+    call_conv_transpose1d(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        dilation,
+        stride,
+        padding,
+        out_padding,
+        c_out,
+        l_out,
+        b_size,
+        input_shape,
+        input_stride,
+        kernel_shape,
+        kernel_stride,
+        &input,
+        0,
+        &kernel,
+        0,
+        &output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec(&output, dst_el)
+}
+
+#[test]
+fn conv_transpose1d_f32() {
+    let input = vec![1.0f32, 2.0, 3.0, 4.0];
+    let input_shape = &[1, 1, 4];
+    let input_stride = &[4, 4, 1];
+
+    let kernel = vec![1.0f32, 2.0, 3.0, 4.0];
+    let kernel_shape = &[1, 1, 4];
+    let kernel_stride = &[4, 4, 1];
+
+    let results = run_conv_transpose1d(
+        &input,
+        input_shape,
+        input_stride,
+        &kernel,
+        kernel_shape,
+        kernel_stride,
+        1,
+        1,
+        0,
+        0,
+        "conv_transpose1d_f32",
+    );
+
+    let expected = vec![1., 4., 10., 20., 25., 24., 16.];
+    assert_eq!(results, expected);
+}
+
+#[test]
+fn conv_transpose1d_f16() {
+    let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect();
+    let input_shape = &[1, 1, 4];
+    let input_stride = &[4, 4, 1];
+
+    let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect();
+    let kernel_shape = &[1, 1, 4];
+    let kernel_stride = &[4, 4, 1];
+
+    let results = run_conv_transpose1d(
+        &input,
+        input_shape,
+        input_stride,
+        &kernel,
+        kernel_shape,
+        kernel_stride,
+        1,
+        1,
+        0,
+        0,
+        "conv_transpose1d_f16",
+    );
+
+    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect::<Vec<_>>();
+    assert_eq!(results, expected);
+}
+
+#[test]
+fn conv_transpose1d_bf16() {
+    let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+        .iter()
+        .map(|v| bf16::from_f32(*v))
+        .collect();
+    let input_shape = &[1, 1, 4];
+    let input_stride = &[4, 4, 1];
+
+    let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+        .iter()
+        .map(|v| bf16::from_f32(*v))
+        .collect();
+    let kernel_shape = &[1, 1, 4];
+    let kernel_stride = &[4, 4, 1];
+
+    let results = run_conv_transpose1d(
+        &input,
+        input_shape,
+        input_stride,
+        &kernel,
+        kernel_shape,
+        kernel_stride,
+        1,
+        1,
+        0,
+        0,
+        "conv_transpose1d_bf16",
+    );
+
+    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+        .iter()
+        .map(|v| bf16::from_f32(*v))
+        .collect::<Vec<_>>();
+    assert_eq!(results, expected);
+}
+
+#[test]
+fn conv_transpose1d_u8() {
+    let input: Vec<u8> = vec![1, 2, 3, 4];
+    let input_shape = &[1, 1, 4];
+    let input_stride = &[4, 4, 1];
+
+    let kernel: Vec<u8> = vec![1, 2, 3, 4];
+    let kernel_shape = &[1, 1, 4];
+    let kernel_stride = &[4, 4, 1];
+
+    let results = run_conv_transpose1d(
+        &input,
+        input_shape,
+        input_stride,
+        &kernel,
+        kernel_shape,
+        kernel_stride,
+        1,
+        1,
+        0,
+        0,
+        "conv_transpose1d_u8",
+    );
+
+    let expected = vec![1, 4, 10, 20, 25, 24, 16];
+    assert_eq!(results, expected);
+}
+
+#[test]
+fn conv_transpose1d_u32() {
+    let input: Vec<u32> = vec![1, 2, 3, 4];
+    let input_shape = &[1, 1, 4];
+    let input_stride = &[4, 4, 1];
+
+    let kernel: Vec<u32> = vec![1, 2, 3, 4];
+    let kernel_shape = &[1, 1, 4];
+    let kernel_stride = &[4, 4, 1];
+
+    let results = run_conv_transpose1d(
+        &input,
+        input_shape,
+        input_stride,
+        &kernel,
+        kernel_shape,
+        kernel_stride,
+        1,
+        1,
+        0,
+        0,
+        "conv_transpose1d_u32",
+    );
+
+    let expected = vec![1, 4, 10, 20, 25, 24, 16];
+    assert_eq!(results, expected);
+}
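One design note on the dtype coverage added at the end of conv.metal: the f16 and bf16 kernels instantiate CONVT1D_OP with a float accumulator (TYPEACC = float) rather than accumulating in the storage type. A minimal sketch of the rounding loss this avoids, using the half crate the tests already depend on; the 2048.0 magnitude is an illustrative choice, not taken from the patch:

use half::f16;

fn main() {
    // At magnitude 2048 the gap between adjacent f16 values is 2.0, so the
    // exact sum 2049.0 is unrepresentable and rounds back down to 2048.0
    // if each partial sum is stored in f16.
    let a = f16::from_f32(2048.0);
    let b = f16::from_f32(1.0);
    let rounded_each_step = f16::from_f32(a.to_f32() + b.to_f32());
    assert_eq!(rounded_each_step, f16::from_f32(2048.0));

    // Accumulating in f32 and converting once at the end keeps the value,
    // which is what TYPEACC = float buys the f16/bf16 kernels.
    let acc_f32 = a.to_f32() + b.to_f32();
    assert_eq!(acc_f32, 2049.0);
}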