diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-21 13:08:45 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 18:08:45 +0100 |
commit | 9563a5fee42f8fef754c238e28ca79725813cea1 (patch) | |
tree | 15f8e7bdc192b04da1e4ac7d32a85cf7c912cabb | |
parent | ec97c98e81707c8f66db6be22d2df7c8791c55b8 (diff) | |
download | candle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.gz candle-9563a5fee42f8fef754c238e28ca79725813cea1.tar.bz2 candle-9563a5fee42f8fef754c238e28ca79725813cea1.zip |
Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types
* update bench calculation
* enable testing all conv operations on metal
-rw-r--r-- | candle-core/benches/bench_main.rs | 3 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/conv_transpose2d.rs | 59 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/mod.rs | 1 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 66 | ||||
-rw-r--r-- | candle-core/tests/conv_tests.rs | 124 | ||||
-rw-r--r-- | candle-metal-kernels/src/conv.metal | 86 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 58 |
7 files changed, 321 insertions, 76 deletions
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 162e3f2b..9f94b252 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -5,5 +5,6 @@ criterion_main!( benchmarks::affine::benches, benchmarks::matmul::benches, benchmarks::random::benches, - benchmarks::where_cond::benches + benchmarks::where_cond::benches, + benchmarks::conv_transpose2d::benches, ); diff --git a/candle-core/benches/benchmarks/conv_transpose2d.rs b/candle-core/benches/benchmarks/conv_transpose2d.rs new file mode 100644 index 00000000..7b252ec6 --- /dev/null +++ b/candle-core/benches/benchmarks/conv_transpose2d.rs @@ -0,0 +1,59 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run( + x: &Tensor, + k: &Tensor, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, +) { + x.conv_transpose2d(k, padding, output_padding, stride, dilation) + .unwrap(); +} + +fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let t = Tensor::arange(0.0f32, 10000.0, device) + .unwrap() + .reshape((1, 4, 50, 50)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let kernel = Tensor::arange(0.0f32, 100.0, device) + .unwrap() + .reshape((4, 1, 5, 5)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&t), black_box(&kernel), 1, 0, 1, 2); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32"); + run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16"); + run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index c45effee..a0ffa3eb 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod affine; +pub(crate) mod conv_transpose2d; pub(crate) mod matmul; pub(crate) mod random; pub(crate) mod where_cond; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c4245652..4f4162e2 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -2,8 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; -use candle_metal_kernels; use candle_metal_kernels::Kernels; +use candle_metal_kernels::{self, CallConvTranspose2dCfg}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; @@ -1074,12 +1074,66 @@ impl BackendStorage for MetalStorage { fn conv_transpose2d( &self, - _l: &Layout, - _kernel: &Self, - _kernel_l: &Layout, - _params: &ParamsConvTranspose2D, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &ParamsConvTranspose2D, ) -> Result<Self> { - crate::bail!("Metal conv_tranpose2d not implemented") + // Kernel shape: (c_in_k, c_out, h_k, w_k) + // Input shape: (b_size, c_in, h_in, w_in) + let (out_w, out_h) = (params.out_w(), params.out_h()); + let dst_el = params.c_out * out_w * out_h * params.b_size; + + let dims = l.dims(); + if dims.len() != 4 { + crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4") + } + + let k_dims = kernel_l.dims(); + if k_dims.len() != 4 { + crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4") + } + + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose2d")?; + + let command_buffer = self.device.command_buffer()?; + + let name = match self.dtype { + DType::F32 => "conv_transpose2d_f32", + DType::F16 => "conv_transpose2d_f16", + DType::BF16 => "conv_transpose2d_bf16", + dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"), + }; + + candle_metal_kernels::call_conv_transpose2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + CallConvTranspose2dCfg { + dilation: params.dilation, + stride: params.stride, + padding: params.padding, + output_padding: params.output_padding, + c_out: params.c_out, + out_h: out_h, + out_w: out_w, + b_size: params.b_size, + input_dims: l.dims(), + input_stride: l.stride(), + kernel_dims: kernel_l.dims(), + kernel_stride: kernel_l.stride(), + input_offset: l.start_offset() * self.dtype.size_in_bytes(), + kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(), + }, + &self.buffer, + &kernel.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) } fn avg_pool2d( diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 71bf65be..6cc48ec7 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -163,33 +163,34 @@ fn conv2d(dev: &Device) -> Result<()> { 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 ] ); - if !dev.is_metal() { - let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; - assert_eq!(res.dims(), [1, 2, 7, 7]); - assert_eq!( - test_utils::to_vec3_round(&res.i(0)?, 4)?, + + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + + assert_eq!(res.dims(), [1, 2, 7, 7]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], + [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], + [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], + [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], + [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], + [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], + [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] + ], [ - [ - [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], - [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], - [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], - [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], - [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], - [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], - [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] - ], - [ - [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], - [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], - [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], - [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], - [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], - [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], - [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171] - ] + [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], + [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], + [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], + [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], + [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], + [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], + [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171] ] - ); - } + ] + ); + // Dilations. let res = t.conv2d(&w, 0, 1, 2, 1)?; assert_eq!(res.dims(), [1, 2, 1, 1]); @@ -198,44 +199,37 @@ fn conv2d(dev: &Device) -> Result<()> { [2.45, -2.3504], ); - if !dev.is_metal() { - // Transpose and dilations. - let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; - assert_eq!(res.dims(), [1, 2, 9, 9]); - assert_eq!( - test_utils::to_vec3_round(&res.i(0)?, 4)?, + // Transpose and dilations. + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; + assert_eq!(res.dims(), [1, 2, 9, 9]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], + [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], + [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], + [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], + [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], + [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], + [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024], + [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], + [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] + ], [ - [ - [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], - [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], - [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], - [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], - [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], - [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], - [ - -2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, - -3.5024 - ], - [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], - [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] - ], - [ - [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], - [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], - [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], - [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], - [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], - [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], - [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], - [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908], - [ - -5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, - 1.0171 - ] - ] + [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], + [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], + [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], + [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], + [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], + [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], + [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], + [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908], + [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171] ] - ); - } + ] + ); + Ok(()) } @@ -290,11 +284,6 @@ fn conv2d_small(dev: &Device) -> Result<()> { ] ); - // conv-transposes are not implemented for metal - if dev.is_metal() { - return Ok(()); - } - let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!( @@ -397,9 +386,6 @@ print(w.grad[0]) */ fn conv2d_grad(dev: &Device) -> Result<()> { // conv-transposes are not implemented for metal - if dev.is_metal() { - return Ok(()); - } use candle_core::Var; let t = Var::from_slice( &[ diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal index a258ae58..e28ac6b3 100644 --- a/candle-metal-kernels/src/conv.metal +++ b/candle-metal-kernels/src/conv.metal @@ -405,6 +405,86 @@ kernel void FN_NAME( \ conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \ } \ +template <typename T, typename A> +METAL_FUNC void conv_transpose2d( + constant size_t &w_out, + constant size_t &h_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *input_dims, + constant size_t *input_stride, + constant size_t *k_dims, + constant size_t *k_stride, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; + const size_t c_out = k_dims[1]; + const size_t c_in = input_dims[1]; + const size_t h_in = input_dims[2]; + const size_t w_in = input_dims[3]; + + if (tid >= input_dims[0] * c_out * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c_out); + const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out; + const size_t out_y = (tid / w_out) % h_out; + const size_t out_x = tid % w_out; + + const size_t src_idx0 = b_idx * input_stride[0]; + + A d = 0; + for (int k_x = 0; k_x < (int)w_k; ++k_x) { + const int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + const int inp_x = inp_x_stride / stride; + if (inp_x >= w_in) continue; + for (int k_y = 0; k_y < (int)h_k; ++k_y) { + const int inp_y_stride = (int)(out_y + padding) - k_y * dilation; + if (inp_y_stride < 0 || inp_y_stride % stride) { + continue; + } + const int inp_y = inp_y_stride / stride; + if (inp_y >= h_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3]; + const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3]; + d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]); + } + } + } + dst[tid] = static_cast<T>(d); +} + +#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_out, \ + constant size_t &h_out, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &out_padding, \ + constant size_t &dilation, \ + constant size_t *input_dims, \ + constant size_t *input_stride, \ + constant size_t *k_dims, \ + constant size_t *k_stride, \ + device const TYPENAME *src, \ + device const TYPENAME *k, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \ +} \ + IM2COL_OP(float, im2col_f32) IM2COL_OP(uint8_t, im2col_u8) IM2COL_OP(uint32_t, im2col_u32) @@ -439,4 +519,10 @@ CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8) CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32) #if defined(__HAVE_BFLOAT__) CONVT1D_OP(bfloat, float, conv_transpose1d_bf16) +#endif + +CONVT2D_OP(float, float, conv_transpose2d_f32) +CONVT2D_OP(half, float, conv_transpose2d_f16) +#if defined(__HAVE_BFLOAT__) +CONVT1D_OP(bfloat, float, conv_transpose2d_bf16) #endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index bab44a05..f2c9c7fe 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1970,5 +1970,63 @@ pub fn call_conv_transpose1d( Ok(()) } +pub struct CallConvTranspose2dCfg<'a> { + pub dilation: usize, + pub stride: usize, + pub padding: usize, + pub output_padding: usize, + pub c_out: usize, + pub out_w: usize, + pub out_h: usize, + pub b_size: usize, + pub input_dims: &'a [usize], + pub input_stride: &'a [usize], + pub kernel_dims: &'a [usize], + pub kernel_stride: &'a [usize], + pub input_offset: usize, + pub kernel_offset: usize, +} + +pub fn call_conv_transpose2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + cfg: CallConvTranspose2dCfg, + input: &Buffer, + kernel: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + cfg.out_w, + cfg.out_h, + cfg.stride, + cfg.padding, + cfg.output_padding, + cfg.dilation, + cfg.input_dims, + cfg.input_stride, + cfg.kernel_dims, + cfg.kernel_stride, + (input, cfg.input_offset), + (kernel, cfg.kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; |