diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index aa157a2f..1815dd32 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1652,6 +1652,39 @@ pub fn call_im2col1d_strided( } #[allow(clippy::too_many_arguments)] +pub fn call_col2im1d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + k_size: usize, + stride: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let l_in = shape[1]; + let c_out = shape[2]; + let l_out = (l_in - 1) * stride + k_size; + let dst_el = shape[0] * c_out * l_out; + + let encoder = command_buffer.new_compute_command_encoder(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (dst_el, l_out, l_in, c_out, k_size, stride, &input, output) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] pub fn call_im2col_strided( device: &Device, command_buffer: &CommandBufferRef, |