summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs33
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index aa157a2f..1815dd32 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1652,6 +1652,39 @@ pub fn call_im2col1d_strided(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_col2im1d(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ k_size: usize,
+ stride: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+ let l_in = shape[1];
+ let c_out = shape[2];
+ let l_out = (l_in - 1) * stride + k_size;
+ let dst_el = shape[0] * c_out * l_out;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!(
+ encoder,
+ (dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
+ );
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,