| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-23 17:00:00 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-23 16:00:00 +0100 |
| commit | 23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0 (patch) | |
| tree | 04404a97a114126cd5faaaeb97a486f9cdb7b920 | |
| parent | e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca (diff) | |
| download | candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.gz candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.bz2 candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.zip | |
Cleanup some todos. (#226)
* Cleanup some todos.
* Fix more todos.
* Optimize for the contiguous case.
* Add the IntDType trait (see the sketch after this list).
* Handle the IntDType trait for more ops.
* Remove a todo.
* Remove a todo.
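For reference, the core of the change is the new `IntDType` trait added in `candle-core/src/dtype.rs` (visible in the diff below). The following is a minimal standalone sketch: the trait and its `u32`/`u8` impls are copied from the patch, while the `where_cond` helper and `main` are illustrative additions in the spirit of the reworked `WCond` kernel, not candle code.

```rust
// Sketch of the IntDType trait added in candle-core/src/dtype.rs by this commit.
// It lets index/mask kernels accept either u8 or u32 tensors instead of
// hard-coding u32.
pub trait IntDType {
    /// Non-zero values are treated as "true" by where-cond style ops.
    fn is_true(&self) -> bool;
    /// Widen the index to usize so it can be used for slicing.
    fn as_usize(&self) -> usize;
}

impl IntDType for u32 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

impl IntDType for u8 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

fn main() {
    // Illustrative helper in the spirit of the reworked WCond kernel:
    // pick t[i] where the predicate is true, f[i] otherwise.
    fn where_cond<I: IntDType, T: Copy>(pred: &[I], t: &[T], f: &[T]) -> Vec<T> {
        pred.iter()
            .zip(t.iter().zip(f.iter()))
            .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
            .collect()
    }
    let pred: [u8; 3] = [1, 0, 1];
    let out = where_cond(&pred, &[1.0f32, 2.0, 3.0], &[10.0, 20.0, 30.0]);
    assert_eq!(out, vec![1.0f32, 20.0, 3.0]);
}
```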
| -rw-r--r-- | candle-core/src/cpu_backend.rs | 182 |
| -rw-r--r-- | candle-core/src/dtype.rs | 23 |
| -rw-r--r-- | candle-core/src/lib.rs | 2 |
| -rw-r--r-- | candle-core/src/storage.rs | 1 |
| -rw-r--r-- | candle-core/tests/tensor_tests.rs | 7 |
| -rw-r--r-- | candle-kernels/src/reduce.cu | 192 |

6 files changed, 232 insertions(+), 175 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 82e1f3e2..9a6320ec 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,6 +1,6 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; -use crate::{DType, Error, Layout, Result, Shape, WithDType}; +use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator + @@ -133,9 +133,9 @@ impl Map2U8 for Cmp { } } -struct WCond<'a>(&'a [u32], &'a Layout); +struct WCond<'a, T: IntDType>(&'a [T], &'a Layout); -impl<'a> Map2 for WCond<'a> { +impl<'a, I: IntDType> Map2 for WCond<'a, I> { const OP: &'static str = "where"; #[inline(always)] fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> { @@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> { let f = &f[o_f1..o_f2]; pred.iter() .zip(t.iter().zip(f.iter())) - .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) + .map(|(p, (&t, &f))| if p.is_true() { t } else { f }) .collect::<Vec<_>>() } _ => self .1 .strided_index() .zip(t_l.strided_index().zip(f_l.strided_index())) - .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] }) + .map(|(i_p, (i_t, i_f))| { + if self.0[i_p].is_true() { + t[i_t] + } else { + f[i_f] + } + }) .collect::<Vec<_>>(), }; Ok(vs) @@ -628,13 +634,13 @@ impl Map1 for Affine { } } -struct Gather<'a> { - ids: &'a [u32], +struct Gather<'a, I: IntDType> { + ids: &'a [I], ids_l: &'a Layout, dim: usize, } -impl<'a> Map1 for Gather<'a> { +impl<'a, I: IntDType> Map1 for Gather<'a, I> { fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], @@ -663,7 +669,7 @@ impl<'a> Map1 for Gather<'a> { let start_dst_idx = start_dst_idx + i * dst_right_len; for right_i in 0..dst_right_len { let dst_idx = start_dst_idx + right_i; - let index = ids[dst_idx] as usize; + let index = ids[dst_idx].as_usize(); if index >= src_dim_len { Err(Error::InvalidIndex { index, @@ -681,13 +687,13 @@ impl<'a> Map1 for Gather<'a> { } } -struct IndexSelect<'a> { - ids: &'a [u32], +struct IndexSelect<'a, T: IntDType> { + ids: &'a [T], ids_l: &'a Layout, dim: usize, } -impl<'a> Map1 for IndexSelect<'a> { +impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> { fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> { let src = match layout.contiguous_offsets() { Some((a, b)) => &src[a..b], @@ -714,7 +720,7 @@ impl<'a> Map1 for IndexSelect<'a> { let start_src_idx = left_i * right_len * src_dim; let start_dst_idx = left_i * right_len * n_ids; for i in 0..n_ids { - let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize; + let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize(); if index >= src_dim { Err(Error::InvalidIndex { index, @@ -733,13 +739,13 @@ impl<'a> Map1 for IndexSelect<'a> { } } -struct ScatterAdd<'a> { - ids: &'a [u32], +struct ScatterAdd<'a, I: IntDType> { + ids: &'a [I], ids_l: &'a Layout, dim: usize, } -impl<'a> Map2 for ScatterAdd<'a> { +impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { const OP: &'static str = "scatter-add"; fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> { let dst_len = l1.shape().elem_count(); @@ -771,7 +777,7 @@ impl<'a> Map2 for ScatterAdd<'a> { let start_ids_idx = start_ids_idx + i * ids_right_len; for 
right_i in 0..dst_right_len { let ids_idx = start_ids_idx + right_i; - let index = ids[ids_idx] as usize; + let index = ids[ids_idx].as_usize(); if index >= dst_dim_len { Err(Error::InvalidIndex { index, @@ -790,12 +796,12 @@ impl<'a> Map2 for ScatterAdd<'a> { } } -struct IndexAdd<'a> { - ids: &'a [u32], +struct IndexAdd<'a, I: IntDType> { + ids: &'a [I], dim: usize, } -impl<'a> Map2 for IndexAdd<'a> { +impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { const OP: &'static str = "index-add"; // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_ // v1, l1 -> self @@ -811,8 +817,8 @@ impl<'a> Map2 for IndexAdd<'a> { let max_idx = l1.dims()[dim]; let stride = src_l.stride()[dim]; if dim == 0 { - for (src_idx, &dst_idx) in self.ids.iter().enumerate() { - let dst_idx = dst_idx as usize; + for (src_idx, dst_idx) in self.ids.iter().enumerate() { + let dst_idx = dst_idx.as_usize(); if dst_idx >= max_idx { Err(Error::InvalidIndex { index: dst_idx, @@ -831,8 +837,8 @@ impl<'a> Map2 for IndexAdd<'a> { } else { let pre_dim = src_l.dims()[..dim].iter().product::<usize>(); let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>(); - for (src_idx, &dst_idx) in self.ids.iter().enumerate() { - let dst_idx = dst_idx as usize; + for (src_idx, dst_idx) in self.ids.iter().enumerate() { + let dst_idx = dst_idx.as_usize(); if dst_idx >= max_idx { Err(Error::InvalidIndex { index: dst_idx, @@ -856,31 +862,52 @@ impl<'a> Map2 for IndexAdd<'a> { } } -struct Embedding<'a> { +struct Embedding<'a, I: IntDType> { vocab_size: usize, hidden_size: usize, - ids: &'a [u32], + ids: &'a [I], ids_l: &'a Layout, } -impl<'a> Map1 for Embedding<'a> { +impl<'a, I: IntDType> Map1 for Embedding<'a, I> { fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> { - // TODO: We assume that vs is contiguous here. + if !layout.is_contiguous() { + Err(Error::RequiresContiguous { op: "embedding" })? + } let vs = &vs[layout.start_offset()..]; let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size); - // TODO: Optimize for the case where ids are contiguous. - for index in self.ids_l.strided_index() { - let index = self.ids[index].try_into()?; - if index >= self.vocab_size { - Err(Error::InvalidIndex { - index, - size: self.vocab_size, - op: "take", + match self.ids_l.contiguous_offsets() { + Some((o1, o2)) => { + for index in self.ids[o1..o2].iter() { + let index = index.as_usize(); + if index >= self.vocab_size { + Err(Error::InvalidIndex { + index, + size: self.vocab_size, + op: "take", + } + .bt())? + } else { + let hidden_size = self.hidden_size; + values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); + } + } + } + None => { + for index in self.ids_l.strided_index() { + let index = self.ids[index].as_usize(); + if index >= self.vocab_size { + Err(Error::InvalidIndex { + index, + size: self.vocab_size, + op: "take", + } + .bt())? + } else { + let hidden_size = self.hidden_size; + values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); + } } - .bt())? - } else { - let hidden_size = self.hidden_size; - values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); } } Ok(values) @@ -1671,9 +1698,11 @@ impl BackendStorage for CpuStorage { f: &Self, f_l: &Layout, ) -> Result<Self> { - // TODO: Support types that could be casted to a boolean. 
- let pred = self.as_slice::<u32>()?; - WCond(pred, layout).map(t, t_l, f, f_l) + match self { + Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), + } } fn conv1d( @@ -1687,25 +1716,40 @@ impl BackendStorage for CpuStorage { } fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { - let ids = self.as_slice::<u32>()?; let (vocab_size, hidden_size) = rhs_l.shape().dims2()?; - Embedding { - vocab_size, - hidden_size, - ids, - ids_l, + match self { + Self::U8(ids) => Embedding { + vocab_size, + hidden_size, + ids, + ids_l, + } + .map(rhs, rhs_l), + Self::U32(ids) => Embedding { + vocab_size, + hidden_size, + ids, + ids_l, + } + .map(rhs, rhs_l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")), } - .map(rhs, rhs_l) } fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> { - let ids = ids.as_slice::<u32>()?; - IndexSelect { ids, ids_l, dim }.map(self, l) + match ids { + Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")), + } } fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { - let ids = ids.as_slice::<u32>()?; - Gather { ids, ids_l, dim }.map(self, l) + match ids { + Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")), + } } fn scatter_add( @@ -1717,8 +1761,11 @@ impl BackendStorage for CpuStorage { src_l: &Layout, dim: usize, ) -> Result<Self> { - let ids = ids.as_slice::<u32>()?; - ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l) + match ids { + Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")), + } } fn index_add( @@ -1730,12 +1777,23 @@ impl BackendStorage for CpuStorage { src_l: &Layout, dim: usize, ) -> Result<Self> { - let ids = ids.as_slice::<u32>()?; - let ids = match ids_l.contiguous_offsets() { - Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, - }; - IndexAdd { ids, dim }.map(self, l, src, src_l) + match ids { + Self::U8(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" })?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::U32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" })?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")), + } } fn matmul( diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 59802c04..c6befbb8 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -119,3 +119,26 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); + +pub trait IntDType { + fn is_true(&self) -> bool; + fn as_usize(&self) -> usize; +} + +impl IntDType for u32 { + fn 
is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for u8 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 5a35955f..787ea63a 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -61,7 +61,7 @@ mod variable; pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; -pub use dtype::{DType, WithDType}; +pub use dtype::{DType, IntDType, WithDType}; pub use error::{Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 5e6cfdf2..52af5861 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -206,7 +206,6 @@ impl Storage { } pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> { - // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => { let storage = storage.unary_impl::<B>(layout)?; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index a126d634..95ce982a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -270,7 +270,11 @@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0] ] ); - // TODO: This is not the expected answer, to be fixed! + // PyTorch equivalent: + // import torch + // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]]) + // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]]) + // torch.cat([t1.t(), t2.t()], dim=1).t() assert_eq!( Tensor::cat(&[&t1.t()?, &t2.t()?], 1)? .t()? @@ -282,7 +286,6 @@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0] ] ); - // TODO: This is not the expected answer, to be fixed! assert_eq!( Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?, [ diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 34caf12b..39a09069 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -1,26 +1,20 @@ -// TODO: Use a proper distributed reduction rather than atomicAdd. -// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf #include "cuda_utils.cuh" -#include<stdint.h> -#include<cmath> +#include <cmath> +#include <stdint.h> const int BLOCK_SIZE = 1024; -// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but -// also expect a f32 output so that this can be used for normalization e.g. in softmax. +// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 +// but also expect a f32 output so that this can be used for normalization e.g. +// in softmax. // Fast reduce sum kernel, this assumes that the dimensions to loop over are at -// the end, each block is responsible for populating one value in the output array. -// There are at most 1024 threads per block. +// the end, each block is responsible for populating one value in the output +// array. There are at most 1024 threads per block. 
template <typename T> -__device__ void fast_sum( - const size_t src_numel, - const size_t el_to_sum_per_block, - const size_t num_dims, - const size_t *info, - const T *src, - T *dst -) { +__device__ void +fast_sum(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { const size_t *dims = info; const size_t *strides = info + num_dims; @@ -47,21 +41,18 @@ __device__ void fast_sum( // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce for (int s = blockDim.x / 2; s > 0; s >>= 1) { __syncthreads(); - if (tid < s) shr[tid] += shr[tid + s]; + if (tid < s) + shr[tid] += shr[tid + s]; } - if (tid == 0) dst[dst_id] = shr[0]; + if (tid == 0) + dst[dst_id] = shr[0]; } template <typename T> -__device__ void fast_max( - const size_t src_numel, - const size_t el_to_sum_per_block, - const size_t num_dims, - const size_t *info, - const T *src, - T *dst -) { +__device__ void +fast_max(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { const size_t *dims = info; const size_t *strides = info + num_dims; @@ -88,21 +79,18 @@ __device__ void fast_max( // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce for (int s = blockDim.x / 2; s > 0; s >>= 1) { __syncthreads(); - if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]); + if (tid < s) + shr[tid] = maxg(shr[tid], shr[tid + s]); } - if (tid == 0) dst[dst_id] = shr[0]; + if (tid == 0) + dst[dst_id] = shr[0]; } template <typename T> -__device__ void fast_min( - const size_t src_numel, - const size_t el_to_sum_per_block, - const size_t num_dims, - const size_t *info, - const T *src, - T *dst -) { +__device__ void +fast_min(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { const size_t *dims = info; const size_t *strides = info + num_dims; @@ -129,83 +117,69 @@ __device__ void fast_min( // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce for (int s = blockDim.x / 2; s > 0; s >>= 1) { __syncthreads(); - if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]); + if (tid < s) + shr[tid] = ming(shr[tid], shr[tid + s]); } - if (tid == 0) dst[dst_id] = shr[0]; + if (tid == 0) + dst[dst_id] = shr[0]; } -#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \ -extern "C" __global__ void MIN_NAME( \ - const size_t src_numel, \ - const size_t el_to_sum_per_block, \ - const size_t num_dims, \ - const size_t *info, \ - const TYPENAME *src, \ - TYPENAME *dst \ -) { \ - fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ -} \ -extern "C" __global__ void MAX_NAME( \ - const size_t src_numel, \ - const size_t el_to_sum_per_block, \ - const size_t num_dims, \ - const size_t *info, \ - const TYPENAME *src, \ - TYPENAME *dst \ -) { \ - fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ -} \ -extern "C" __global__ void SUM_NAME( \ - const size_t src_numel, \ - const size_t el_to_sum_per_block, \ - const size_t num_dims, \ - const size_t *info, \ - const TYPENAME *src, \ - TYPENAME *dst \ -) { \ - fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ -} \ - -#define SUM_OP(TYPENAME, FN_NAME) \ -extern "C" __global__ void FN_NAME( \ - const size_t numel, \ - const size_t num_dims, \ - const 
size_t num_sum_dims, \ - const size_t *info, \ - const TYPENAME *inp, \ - TYPENAME *out \ -) { \ - const size_t *dims = info; \ - const size_t *strides = info + num_dims; \ - const size_t *sum_dims_l = info + 2*num_dims; \ - const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \ - if (is_contiguous(num_dims, dims, strides)) { \ - for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ - size_t dst_index = i; \ - for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ - size_t stride = sum_dims_s[nd]; \ - size_t pre = dst_index / stride; \ - size_t post = dst_index % stride; \ - dst_index = (pre / sum_dims_l[nd]) * stride + post; \ - } \ - atomicAdd(out + dst_index, inp[i]); \ - } \ - } \ - else { \ - for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ - unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ - size_t dst_index = i; \ - for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ - size_t stride = sum_dims_s[nd]; \ - size_t pre = dst_index / stride; \ - size_t post = dst_index % stride; \ - dst_index = (pre / sum_dims_l[nd]) * stride + post; \ - } \ - atomicAdd(out + dst_index, inp[strided_i]); \ - } \ - } \ -} \ +#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \ + extern "C" __global__ void MIN_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void MAX_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void SUM_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } + +#define SUM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const size_t numel, const size_t num_dims, const size_t num_sum_dims, \ + const size_t *info, const TYPENAME *inp, TYPENAME *out) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + const size_t *sum_dims_l = info + 2 * num_dims; \ + const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \ + if (is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \ + i += blockDim.x * gridDim.x) { \ + size_t dst_index = i; \ + for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ + size_t stride = sum_dims_s[nd]; \ + size_t pre = dst_index / stride; \ + size_t post = dst_index % stride; \ + dst_index = (pre / sum_dims_l[nd]) * stride + post; \ + } \ + atomicAdd(out + dst_index, inp[i]); \ + } \ + } else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \ + i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + size_t dst_index = i; \ + for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ + size_t stride = sum_dims_s[nd]; \ + size_t pre = dst_index / stride; \ + size_t post = dst_index % stride; \ + dst_index = (pre / sum_dims_l[nd]) * stride + post; \ + } \ + atomicAdd(out + dst_index, inp[strided_i]); \ + } \ + } \ + } #if __CUDA_ARCH__ >= 800 
SUM_OP(__nv_bfloat16, sum_bf16)
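Beyond the trait itself, the recurring pattern in the `cpu_backend.rs` hunks above is a `match` on the index/predicate storage that picks the concrete `IntDType` before calling a generic kernel, with an unsupported-dtype error for everything else. The sketch below mirrors only that shape; `Storage`, `Error`, and `index_select_kernel` are simplified stand-ins, not candle's real `CpuStorage`/`Error` types.

```rust
// Simplified stand-ins; candle's real CpuStorage/Error carry more variants.
#[derive(Debug)]
enum Storage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    F32(Vec<f32>),
}

#[derive(Debug)]
enum Error {
    UnsupportedDTypeForOp(&'static str, &'static str),
}

trait IntDType {
    fn as_usize(&self) -> usize;
}
impl IntDType for u8 {
    fn as_usize(&self) -> usize {
        *self as usize
    }
}
impl IntDType for u32 {
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

// Generic kernel: works for any integer index dtype.
fn index_select_kernel<I: IntDType>(src: &[f32], ids: &[I]) -> Vec<f32> {
    ids.iter().map(|i| src[i.as_usize()]).collect()
}

// Dispatch shaped like the patched BackendStorage::index_select: match once on
// the index storage to pick the concrete IntDType, and fall through to an
// "unsupported dtype" error for non-integer index storage.
fn index_select(src: &[f32], ids: &Storage) -> Result<Vec<f32>, Error> {
    match ids {
        Storage::U8(ids) => Ok(index_select_kernel(src, ids)),
        Storage::U32(ids) => Ok(index_select_kernel(src, ids)),
        _ => Err(Error::UnsupportedDTypeForOp("f32", "index-select")),
    }
}

fn main() {
    let src = [10.0f32, 11.0, 12.0, 13.0];
    let ids = Storage::U8(vec![3, 0, 2]);
    assert_eq!(index_select(&src, &ids).unwrap(), vec![13.0f32, 10.0, 12.0]);
}
```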