diff options
-rw-r--r-- | candle-core/src/cpu_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 65 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 276 | ||||
-rw-r--r-- | candle-examples/examples/simple-training/main.rs | 13 | ||||
-rw-r--r-- | candle-kernels/src/cuda_utils.cuh | 3 | ||||
-rw-r--r-- | candle-kernels/src/reduce.cu | 122 |
6 files changed, 453 insertions, 28 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index c39cb9f7..59c17387 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -244,7 +244,7 @@ impl ReduceIndex { val = s } } - dst[unstr_index] = g(val, acc) + dst_to_set[unstr_index] = g(val, acc) } } } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index a88d62c7..6c98cd0a 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -438,6 +438,28 @@ trait Map2InPlace { } } +trait Map1Any { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( + &self, + src: &CudaSlice<T>, + dev: &CudaDevice, + layout: &Layout, + wrap: W, + ) -> Result<S>; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> { + let out = match s { + S::U8(s) => self.f(s, d, l, S::U8)?, + S::U32(s) => self.f(s, d, l, S::U32)?, + S::BF16(s) => self.f(s, d, l, S::BF16)?, + S::F16(s) => self.f(s, d, l, S::F16)?, + S::F32(s) => self.f(s, d, l, S::F32)?, + S::F64(s) => self.f(s, d, l, S::F64)?, + }; + Ok(out) + } +} + trait Map2Any { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, @@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> { } struct FastReduce<'a>(&'a [usize], ReduceOp); -impl<'a> Map1 for FastReduce<'a> { - fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( +impl<'a> Map1Any for FastReduce<'a> { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>( &self, src: &CudaSlice<T>, dev: &CudaDevice, layout: &Layout, - ) -> Result<CudaSlice<T>> { + wrap: W, + ) -> Result<S> { let src_stride = layout.stride(); let src_dims = layout.shape().dims(); let src_el: usize = src_dims.iter().product(); @@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> { .htod_copy([dims.as_slice(), stride.as_slice()].concat()) .w()?; let src = &src.slice(layout.start_offset()..); - let name = match self.1 { - ReduceOp::Sum => "fast_sum", - ReduceOp::Min => "fast_min", - ReduceOp::Max => "fast_max", - ReduceOp::ArgMin => "fast_argmin", - ReduceOp::ArgMax => "fast_argmax", + let (name, check_empty, return_index) = match self.1 { + ReduceOp::Sum => ("fast_sum", false, false), + ReduceOp::Min => ("fast_min", true, false), + ReduceOp::Max => ("fast_max", true, false), + ReduceOp::ArgMin => ("fast_argmin", true, true), + ReduceOp::ArgMax => ("fast_argmax", true, true), }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? + } let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?; - // SAFETY: filled in by the follow up kernel. - let out = unsafe { dev.alloc::<T>(dst_el) }.w()?; - let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }.w()?; - Ok(out) + if return_index { + // SAFETY: filled in by the follow up kernel. + let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?; + let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U32(out)) + } else { + // SAFETY: filled in by the follow up kernel. + let out = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(wrap(out)) + } } } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index a439ba30..38336ecf 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -164,6 +164,278 @@ fn sum(device: &Device) -> Result<()> { Ok(()) } +fn min(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.min_keepdim(2)?.to_vec3::<u32>()?, + &[[[1], [1]], [[1], [2]]] + ); + assert_eq!( + tensor.min_keepdim(0)?.to_vec3::<u32>()?, + &[[[2, 1, 4], [1, 2, 8]]], + ); + let data: Vec<u32> = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?, + &[[200]] + ); + assert_eq!( + tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?, + &[[200]] + ); + assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?, + &[[200]] + ); + assert_eq!( + tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?, + &[[200]] + ); + assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .min_keepdim(0)? + .min_keepdim(2)? + .min_keepdim(1)? + .to_vec3::<u32>()?, + &[[[200]]] + ); + assert_eq!( + tensor.min_keepdim(0)?.to_vec3::<u32>()?, + &[[ + [200, 201, 202, 203], + [204, 205, 206, 207], + [208, 209, 210, 211], + [212, 213, 214, 215], + [216, 217, 218, 219] + ]] + ); + } + Ok(()) +} + +fn max(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.max_keepdim(2)?.to_vec3::<u32>()?, + &[[[4], [9]], [[7], [8]]] + ); + assert_eq!( + tensor.max_keepdim(0)?.to_vec3::<u32>()?, + &[[[3, 1, 7], [8, 5, 9]]], + ); + let data: Vec<u32> = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?, + &[[3999]] + ); + assert_eq!( + tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?, + &[[3999]] + ); + assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?, + &[[3999]] + ); + assert_eq!( + tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?, + &[[3999]] + ); + assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .max_keepdim(0)? + .max_keepdim(2)? + .max_keepdim(1)? + .to_vec3::<u32>()?, + &[[[3999]]] + ); + assert_eq!( + tensor.max_keepdim(0)?.to_vec3::<u32>()?, + &[[ + [3980, 3981, 3982, 3983], + [3984, 3985, 3986, 3987], + [3988, 3989, 3990, 3991], + [3992, 3993, 3994, 3995], + [3996, 3997, 3998, 3999] + ]] + ); + } + Ok(()) +} + +fn argmin(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.argmin_keepdim(2)?.to_vec3::<u32>()?, + &[[[1], [0]], [[1], [1]]] + ); + assert_eq!( + tensor.argmin_keepdim(0)?.to_vec3::<u32>()?, + &[[[1, 0, 0], [0, 1, 1]]], + ); + let data: Vec<u32> = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(1)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!( + tensor + .argmin_keepdim(1)? + .argmin_keepdim(0)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(1)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!( + tensor + .argmin_keepdim(1)? + .argmin_keepdim(0)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(2)? + .argmin_keepdim(1)? + .to_vec3::<u32>()?, + &[[[0]]] + ); + assert_eq!( + tensor.argmin_keepdim(0)?.to_vec3::<u32>()?, + &[[ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ]] + ); + } + Ok(()) +} + +fn argmax(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.argmax_keepdim(2)?.to_vec3::<u32>()?, + &[[[2], [2]], [[2], [0]]] + ); + assert_eq!( + tensor.argmax_keepdim(0)?.to_vec3::<u32>()?, + &[[[0, 0, 1], [1, 0, 0]]], + ); + let data: Vec<u32> = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(1)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!( + tensor + .argmax_keepdim(1)? + .argmax_keepdim(0)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(1)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!( + tensor + .argmax_keepdim(1)? + .argmax_keepdim(0)? + .to_vec2::<u32>()?, + &[[0]] + ); + assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(2)? + .argmax_keepdim(1)? + .to_vec3::<u32>()?, + &[[[0]]] + ); + assert_eq!( + tensor.argmax_keepdim(0)?.to_vec3::<u32>()?, + &[[ + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + ]] + ); + } + Ok(()) +} + fn narrow(device: &Device) -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; let tensor = Tensor::new(data, device)?; @@ -581,6 +853,10 @@ test_device!(narrow, narrow_cpu, narrow_gpu); test_device!(broadcast, broadcast_cpu, broadcast_gpu); test_device!(cat, cat_cpu, cat_gpu); test_device!(sum, sum_cpu, sum_gpu); +test_device!(min, min_cpu, min_gpu); +test_device!(max, max_cpu, max_gpu); +test_device!(argmax, argmax_cpu, argmax_gpu); +test_device!(argmin, argmin_cpu, argmin_gpu); test_device!(transpose, transpose_cpu, transpose_gpu); test_device!(binary_op, binary_op_cpu, binary_op_gpu); test_device!(embeddings, embeddings_cpu, embeddings_gpu); diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs index f15aa60c..b78d937b 100644 --- a/candle-examples/examples/simple-training/main.rs +++ b/candle-examples/examples/simple-training/main.rs @@ -142,17 +142,20 @@ fn training_loop<M: Model>( let dev = candle::Device::cuda_if_available(0)?; let train_labels = m.train_labels; - let train_images = m.train_images; - let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?; + let train_images = m.train_images.to_device(&dev)?; + let train_labels = train_labels + .to_dtype(DType::U32)? + .unsqueeze(1)? + .to_device(&dev)?; - let vs = VarStore::new(DType::F32, dev); + let vs = VarStore::new(DType::F32, dev.clone()); let model = M::new(vs.clone())?; let all_vars = vs.all_vars(); let all_vars = all_vars.iter().collect::<Vec<_>>(); let sgd = candle_nn::SGD::new(&all_vars, learning_rate); - let test_images = m.test_images; - let test_labels = m.test_labels.to_dtype(DType::U32)?; + let test_images = m.test_images.to_device(&dev)?; + let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?; for epoch in 1..200 { let logits = model.forward(&train_images)?; let log_sm = ops::log_softmax(&logits, D::Minus1)?; diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index fe3acc9e..ffdf4026 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -144,7 +144,8 @@ __device__ __forceinline__ double copysigng(double a, double b) { return copysig __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); } __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); } - +__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); } +__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); } #if __CUDA_ARCH__ >= 530 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); } __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); } diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 39a09069..9d4fc710 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -125,7 +125,116 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block, dst[dst_id] = shr[0]; } -#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \ +template <typename T> +__device__ void +fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + __shared__ uint32_t shr_index[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + // Not sure how that works on uint32_t and uint8_t but it seems to do ok. + shr[tid] = INFINITY; + shr_index[tid] = 0xFFFFFFFF; + bool not_set = true; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (not_set || src[strided_i] < shr[tid]) { + shr[tid] = src[strided_i]; + // Assume that the reduction takes place over the last dimension which is contiguous. + shr_index[tid] = idx % dims[num_dims - 1]; + not_set = false; + } + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s && shr[tid + s] < shr[tid]) { + shr[tid] = shr[tid + s]; + shr_index[tid] = shr_index[tid + s]; + } + } + + if (tid == 0) + dst[dst_id] = shr_index[0]; +} + +template <typename T> +__device__ void +fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + __shared__ uint32_t shr_index[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + shr[tid] = -INFINITY; + shr_index[tid] = 0xFFFFFFFF; + bool not_set = true; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (not_set || src[strided_i] > shr[tid]) { + shr[tid] = src[strided_i]; + // Assume that the reduction takes place over the last dimension which is contiguous. + shr_index[tid] = idx % dims[num_dims - 1]; + not_set = false; + } + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s && shr[tid + s] > shr[tid]) { + shr[tid] = shr[tid + s]; + shr_index[tid] = shr_index[tid + s]; + } + } + + if (tid == 0) + dst[dst_id] = shr_index[0]; +} + +#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, ARGMIN_NAME, ARGMAX_NAME, SUM_NAME) \ + extern "C" __global__ void ARGMIN_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + uint32_t *dst) { \ + fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void ARGMAX_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + uint32_t *dst) { \ + fast_argmax(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ extern "C" __global__ void MIN_NAME( \ const size_t src_numel, const size_t el_to_sum_per_block, \ const size_t num_dims, const size_t *info, const TYPENAME *src, \ @@ -183,18 +292,19 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block, #if __CUDA_ARCH__ >= 800 SUM_OP(__nv_bfloat16, sum_bf16) -FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_sum_bf16) +FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) #endif #if __CUDA_ARCH__ >= 530 SUM_OP(__half, sum_f16) -FAST_OP(__half, fast_min_f16, fast_max_f16, fast_sum_f16) +FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) #endif SUM_OP(float, sum_f32) SUM_OP(double, sum_f64) SUM_OP(uint32_t, sum_u32) -FAST_OP(float, fast_min_f32, fast_max_f32, fast_sum_f32) -FAST_OP(double, fast_min_f64, fast_max_f64, fast_sum_f64) -FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_sum_u32) +FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) +FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) +FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) |