diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-18 03:33:30 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 08:33:30 +0100 |
commit | 754fa1e8134dd78c841c936eca746de9408e9ea7 (patch) | |
tree | ac4233588d2758954fb55487f73af8dff9a73cf4 /candle-metal-kernels/src/lib.rs | |
parent | 184105792f1d5c70ac07da4832938f3963c740dc (diff) | |
download | candle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.gz candle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.bz2 candle-754fa1e8134dd78c841c936eca746de9408e9ea7.zip |
Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a879c86a..b1830a25 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +#[allow(clippy::too_many_arguments)] +pub fn call_max_pool2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; |