summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
author Thomas Santerre <thomas@santerre.xyz> 2024-03-18 03:33:30 -0400
committer GitHub <noreply@github.com> 2024-03-18 08:33:30 +0100
commit 754fa1e8134dd78c841c936eca746de9408e9ea7 (patch)
tree ac4233588d2758954fb55487f73af8dff9a73cf4 /candle-metal-kernels/src/lib.rs
parent 184105792f1d5c70ac07da4832938f3963c740dc (diff)
downloadcandle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.gz
candle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.bz2
candle-754fa1e8134dd78c841c936eca746de9408e9ea7.zip
Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- candle-metal-kernels/src/lib.rs 33
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a879c86a..b1830a25 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
+/// Encodes a 2D max-pooling compute pass onto `command_buffer`.
+///
+/// The pipeline called `name` is loaded from the `Source::Conv` kernel source, bound
+/// together with the pooling parameters and the two buffers, and dispatched over
+/// `out_w * out_h * shape[0] * shape[1]` elements, split into thread groups by
+/// `linear_split`. The encoder is ended before returning; the command buffer itself
+/// is neither committed nor waited on here.
+///
+/// NOTE(review): `shape` / `strides` presumably describe a (batch, channels, height,
+/// width) input — confirm against the Metal kernel.
+///
+/// # Errors
+///
+/// Propagates any error returned by `Kernels::load_pipeline`.
+///
+/// # Panics
+///
+/// Panics if `shape` has fewer than two entries.
+// The argument list mirrors the parameters handed to the kernel, hence its length.
+#[allow(clippy::too_many_arguments)]
+pub fn call_max_pool2d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    out_w: usize,
+    out_h: usize,
+    w_k: usize,
+    h_k: usize,
+    w_stride: usize,
+    h_stride: usize,
+    input: &Buffer,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    // Number of elements the dispatch is sized for.
+    let plane_len = out_w * out_h;
+    let total_out = plane_len * shape[0] * shape[1];
+
+    // Load the pipeline first so a failure returns before any encoder is opened.
+    let pso: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
+    let (group_count, group_size) = linear_split(&pso, total_out);
+
+    let compute_pass = command_buffer.new_compute_command_encoder();
+    compute_pass.set_compute_pipeline_state(&pso);
+    // NOTE(review): the tuple order presumably has to match the kernel's parameter
+    // order — confirm before reordering.
+    set_params!(
+        compute_pass,
+        (w_k, h_k, w_stride, h_stride, shape, strides, input, output)
+    );
+    compute_pass.use_resource(input, metal::MTLResourceUsage::Read);
+    compute_pass.use_resource(output, metal::MTLResourceUsage::Write);
+    compute_pass.dispatch_thread_groups(group_count, group_size);
+    compute_pass.end_encoding();
+    Ok(())
+}
+
#[cfg(test)]
mod tests;