diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-18 13:50:14 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 18:50:14 +0100 |
commit | 04a61a9c72a1f13546c8b7becd95055129fda22f (patch) | |
tree | 3aba0534f7a1b974002dff2595f0c0c7001f1822 /candle-metal-kernels | |
parent | 58605252e8c9355d6f2452f54918e9eb4b938b1f (diff) | |
download | candle-04a61a9c72a1f13546c8b7becd95055129fda22f.tar.gz candle-04a61a9c72a1f13546c8b7becd95055129fda22f.tar.bz2 candle-04a61a9c72a1f13546c8b7becd95055129fda22f.zip |
Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d
* fixX
* add suggested precision workaround for the accumulator
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/conv.metal | 69 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 136 |
3 files changed, 194 insertions, 13 deletions
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index d7c23ddf..7f7a75cf 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -206,6 +206,67 @@ kernel void FN_NAME( \ upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ } \ +template <typename T, typename A> +METAL_FUNC void avg_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + d += static_cast<A>(src[src_idx]); + } + } + dst[tid] = static_cast<T>(d / (w_k * h_k)); +} + +#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + avg_pool2d<TYPENAME, TYPEACC>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + template <typename T> METAL_FUNC void max_pool2d( constant size_t &w_k, @@ -292,4 +353,12 @@ MAXPOOL2D_OP(uint32_t, max_pool2d_u32) MAXPOOL2D_OP(uint8_t, max_pool2d_u8) #if defined(__HAVE_BFLOAT__) MAXPOOL2D_OP(bfloat, max_pool2d_bf16) +#endif + +AVGPOOL2D_OP(float, float, avg_pool2d_f32) +AVGPOOL2D_OP(half, float, avg_pool2d_f16) +AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) +AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16) #endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index b1830a25..1161501f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1827,7 +1827,7 @@ fn divide(m: usize, b: usize) -> NSUInteger { } #[allow(clippy::too_many_arguments)] -pub fn call_max_pool2d( +pub fn call_pool2d( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 74721153..19e160dd 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1369,7 +1369,7 @@ fn index_add() { } } -fn run_max_pool2d<T: Clone>( +fn run_pool2d<T: Clone>( v: &[T], (w_k, h_k): (usize, usize), (w_stride, h_stride): (usize, usize), @@ -1386,7 +1386,7 @@ fn run_max_pool2d<T: Clone>( let input = new_buffer(&device, v); let output = new_buffer(&device, &vec![0.0f32; dst_el]); let kernels = Kernels::new(); - call_max_pool2d( + call_pool2d( &device, command_buffer, &kernels, @@ -1417,7 +1417,7 @@ fn max_pool2d_f32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1434,7 +1434,7 @@ fn max_pool2d_f32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1454,7 +1454,7 @@ fn max_pool2d_f16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1474,7 +1474,7 @@ fn max_pool2d_f16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1497,7 +1497,7 @@ fn max_pool2d_bf16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1517,7 +1517,7 @@ fn max_pool2d_bf16() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1540,7 +1540,7 @@ fn max_pool2d_u8() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1557,7 +1557,7 @@ fn max_pool2d_u8() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1577,7 +1577,7 @@ fn max_pool2d_u32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 1; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1594,7 +1594,7 @@ fn max_pool2d_u32() { let strides = vec![16, 16, 4, 1]; let kernel = 2; let stride = 2; - let results = run_max_pool2d( + let results = run_pool2d( &v, (kernel, kernel), (stride, stride), @@ -1605,3 +1605,115 @@ fn max_pool2d_u32() { let expected = vec![5, 7, 13, 15]; assert_eq!(results, expected); } + +#[test] +fn avg_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec<f32> = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f32", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec<f16> = (0..16).map(|v| f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec<bf16> = (0..16).map(|v| bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_bf16", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec<u8> = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u8", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec<u32> = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u32", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} |