diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-18 03:33:30 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 08:33:30 +0100 |
commit | 754fa1e8134dd78c841c936eca746de9408e9ea7 (patch) | |
tree | ac4233588d2758954fb55487f73af8dff9a73cf4 /candle-metal-kernels | |
parent | 184105792f1d5c70ac07da4832938f3963c740dc (diff) | |
download | candle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.gz candle-754fa1e8134dd78c841c936eca746de9408e9ea7.tar.bz2 candle-754fa1e8134dd78c841c936eca746de9408e9ea7.zip |
Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/conv.metal | 82 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 33 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 239 |
3 files changed, 353 insertions, 1 deletion
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index dca53161..d7c23ddf 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -1,3 +1,9 @@ +#include <metal_stdlib> + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + template <typename T> METAL_FUNC void im2col( constant size_t &dst_numel, @@ -200,6 +206,74 @@ kernel void FN_NAME( \ upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ } \ +template <typename T> +METAL_FUNC void max_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + if (set) { + d = MAX(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[tid] = d; +} + +#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant 
size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + IM2COL_OP(float, im2col_f32) IM2COL_OP(uint8_t, im2col_u8) IM2COL_OP(uint32_t, im2col_u32) @@ -211,3 +285,11 @@ IM2COL1D_OP(uint32_t, im2col1d_u32) UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) + +MAXPOOL2D_OP(float, max_pool2d_f32) +MAXPOOL2D_OP(half, max_pool2d_f16) +MAXPOOL2D_OP(uint32_t, max_pool2d_u32) +MAXPOOL2D_OP(uint8_t, max_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +MAXPOOL2D_OP(bfloat, max_pool2d_bf16) +#endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a879c86a..b1830a25 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1826,5 +1826,38 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +#[allow(clippy::too_many_arguments)] +pub fn call_max_pool2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index a34882d3..74721153 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Buffer, Device, MTLResourceOptions}; +use metal::MTLResourceOptions; fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let ptr = buffer.contents() as *const T; @@ -1368,3 +1368,240 @@ fn index_add() { assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); } } + +fn run_max_pool2d<T: Clone>( + v: &[T], + (w_k, h_k): 
(usize, usize), + (w_stride, h_stride): (usize, usize), + shape: &[usize], + strides: &[usize], + name: &'static str, +) -> Vec<T> { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let out_w = (shape[2] - w_k) / w_stride + 1; + let out_h = (shape[3] - h_k) / h_stride + 1; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let input = new_buffer(&device, v); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + call_max_pool2d( + &device, + command_buffer, + &kernels, + name, + shape, + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &input, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn max_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec<f32> = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec<f32> = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec<half::f16> = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + 
&shape, + &strides, + "max_pool2d_f16", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec<half::f16> = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec<half::bf16> = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec<half::bf16> = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec<u8> = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + 
&v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec<u8> = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec<u32> = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec<u32> = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_max_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} |