diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-21 13:08:45 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 18:08:45 +0100 |
commit | 9563a5fee42f8fef754c238e28ca79725813cea1 (patch) | |
tree | 15f8e7bdc192b04da1e4ac7d32a85cf7c912cabb /candle-metal-kernels/src/lib.rs | |
parent | ec97c98e81707c8f66db6be22d2df7c8791c55b8 (diff) | |
download | candle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.gz candle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.bz2 candle-9563a5fee42f8fef754c238e28ca79725813cea1.zip |
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types
* update bench calculation
* enable testing all conv operations on metal
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index bab44a05..f2c9c7fe 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1970,5 +1970,63 @@ pub fn call_conv_transpose1d( Ok(()) } +pub struct CallConvTranspose2dCfg<'a> { + pub dilation: usize, + pub stride: usize, + pub padding: usize, + pub output_padding: usize, + pub c_out: usize, + pub out_w: usize, + pub out_h: usize, + pub b_size: usize, + pub input_dims: &'a [usize], + pub input_stride: &'a [usize], + pub kernel_dims: &'a [usize], + pub kernel_stride: &'a [usize], + pub input_offset: usize, + pub kernel_offset: usize, +} + +pub fn call_conv_transpose2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + cfg: CallConvTranspose2dCfg, + input: &Buffer, + kernel: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + cfg.out_w, + cfg.out_h, + cfg.stride, + cfg.padding, + cfg.output_padding, + cfg.dilation, + cfg.input_dims, + cfg.input_stride, + cfg.kernel_dims, + cfg.kernel_stride, + (input, cfg.input_offset), + (kernel, cfg.kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; |