summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-23 17:00:00 +0200
committerGitHub <noreply@github.com>2023-07-23 16:00:00 +0100
commit23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0 (patch)
tree04404a97a114126cd5faaaeb97a486f9cdb7b920
parente449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca (diff)
downloadcandle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.gz
candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.bz2
candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.zip
Cleanup some todos. (#226)
* Cleanup some todos. * Fix more todo. * Optimize for the contiguous case. * Add the IntDType trait. * Handle the intdtype trait for more ops. * Remove a todo. * Remove a todo.
-rw-r--r--candle-core/src/cpu_backend.rs182
-rw-r--r--candle-core/src/dtype.rs23
-rw-r--r--candle-core/src/lib.rs2
-rw-r--r--candle-core/src/storage.rs1
-rw-r--r--candle-core/tests/tensor_tests.rs7
-rw-r--r--candle-kernels/src/reduce.cu192
6 files changed, 232 insertions, 175 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 82e1f3e2..9a6320ec 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,6 +1,6 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{DType, Error, Layout, Result, Shape, WithDType};
+use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -133,9 +133,9 @@ impl Map2U8 for Cmp {
}
}
-struct WCond<'a>(&'a [u32], &'a Layout);
+struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
-impl<'a> Map2 for WCond<'a> {
+impl<'a, I: IntDType> Map2 for WCond<'a, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> {
let f = &f[o_f1..o_f2];
pred.iter()
.zip(t.iter().zip(f.iter()))
- .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+ .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
.collect::<Vec<_>>()
}
_ => self
.1
.strided_index()
.zip(t_l.strided_index().zip(f_l.strided_index()))
- .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
+ .map(|(i_p, (i_t, i_f))| {
+ if self.0[i_p].is_true() {
+ t[i_t]
+ } else {
+ f[i_f]
+ }
+ })
.collect::<Vec<_>>(),
};
Ok(vs)
@@ -628,13 +634,13 @@ impl Map1 for Affine {
}
}
-struct Gather<'a> {
- ids: &'a [u32],
+struct Gather<'a, I: IntDType> {
+ ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}
-impl<'a> Map1 for Gather<'a> {
+impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@@ -663,7 +669,7 @@ impl<'a> Map1 for Gather<'a> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
- let index = ids[dst_idx] as usize;
+ let index = ids[dst_idx].as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
@@ -681,13 +687,13 @@ impl<'a> Map1 for Gather<'a> {
}
}
-struct IndexSelect<'a> {
- ids: &'a [u32],
+struct IndexSelect<'a, T: IntDType> {
+ ids: &'a [T],
ids_l: &'a Layout,
dim: usize,
}
-impl<'a> Map1 for IndexSelect<'a> {
+impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@@ -714,7 +720,7 @@ impl<'a> Map1 for IndexSelect<'a> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
- let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
+ let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
@@ -733,13 +739,13 @@ impl<'a> Map1 for IndexSelect<'a> {
}
}
-struct ScatterAdd<'a> {
- ids: &'a [u32],
+struct ScatterAdd<'a, I: IntDType> {
+ ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
}
-impl<'a> Map2 for ScatterAdd<'a> {
+impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
@@ -771,7 +777,7 @@ impl<'a> Map2 for ScatterAdd<'a> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
- let index = ids[ids_idx] as usize;
+ let index = ids[ids_idx].as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@@ -790,12 +796,12 @@ impl<'a> Map2 for ScatterAdd<'a> {
}
}
-struct IndexAdd<'a> {
- ids: &'a [u32],
+struct IndexAdd<'a, I: IntDType> {
+ ids: &'a [I],
dim: usize,
}
-impl<'a> Map2 for IndexAdd<'a> {
+impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@@ -811,8 +817,8 @@ impl<'a> Map2 for IndexAdd<'a> {
let max_idx = l1.dims()[dim];
let stride = src_l.stride()[dim];
if dim == 0 {
- for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
- let dst_idx = dst_idx as usize;
+ for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
@@ -831,8 +837,8 @@ impl<'a> Map2 for IndexAdd<'a> {
} else {
let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
- for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
- let dst_idx = dst_idx as usize;
+ for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+ let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
index: dst_idx,
@@ -856,31 +862,52 @@ impl<'a> Map2 for IndexAdd<'a> {
}
}
-struct Embedding<'a> {
+struct Embedding<'a, I: IntDType> {
vocab_size: usize,
hidden_size: usize,
- ids: &'a [u32],
+ ids: &'a [I],
ids_l: &'a Layout,
}
-impl<'a> Map1 for Embedding<'a> {
+impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
- // TODO: We assume that vs is contiguous here.
+ if !layout.is_contiguous() {
+ Err(Error::RequiresContiguous { op: "embedding" })?
+ }
let vs = &vs[layout.start_offset()..];
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
- // TODO: Optimize for the case where ids are contiguous.
- for index in self.ids_l.strided_index() {
- let index = self.ids[index].try_into()?;
- if index >= self.vocab_size {
- Err(Error::InvalidIndex {
- index,
- size: self.vocab_size,
- op: "take",
+ match self.ids_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ for index in self.ids[o1..o2].iter() {
+ let index = index.as_usize();
+ if index >= self.vocab_size {
+ Err(Error::InvalidIndex {
+ index,
+ size: self.vocab_size,
+ op: "take",
+ }
+ .bt())?
+ } else {
+ let hidden_size = self.hidden_size;
+ values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+ }
+ }
+ }
+ None => {
+ for index in self.ids_l.strided_index() {
+ let index = self.ids[index].as_usize();
+ if index >= self.vocab_size {
+ Err(Error::InvalidIndex {
+ index,
+ size: self.vocab_size,
+ op: "take",
+ }
+ .bt())?
+ } else {
+ let hidden_size = self.hidden_size;
+ values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+ }
}
- .bt())?
- } else {
- let hidden_size = self.hidden_size;
- values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
@@ -1671,9 +1698,11 @@ impl BackendStorage for CpuStorage {
f: &Self,
f_l: &Layout,
) -> Result<Self> {
- // TODO: Support types that could be casted to a boolean.
- let pred = self.as_slice::<u32>()?;
- WCond(pred, layout).map(t, t_l, f, f_l)
+ match self {
+ Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+ Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
+ }
}
fn conv1d(
@@ -1687,25 +1716,40 @@ impl BackendStorage for CpuStorage {
}
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- let ids = self.as_slice::<u32>()?;
let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
- Embedding {
- vocab_size,
- hidden_size,
- ids,
- ids_l,
+ match self {
+ Self::U8(ids) => Embedding {
+ vocab_size,
+ hidden_size,
+ ids,
+ ids_l,
+ }
+ .map(rhs, rhs_l),
+ Self::U32(ids) => Embedding {
+ vocab_size,
+ hidden_size,
+ ids,
+ ids_l,
+ }
+ .map(rhs, rhs_l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
}
- .map(rhs, rhs_l)
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
- let ids = ids.as_slice::<u32>()?;
- IndexSelect { ids, ids_l, dim }.map(self, l)
+ match ids {
+ Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+ Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
+ }
}
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
- let ids = ids.as_slice::<u32>()?;
- Gather { ids, ids_l, dim }.map(self, l)
+ match ids {
+ Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
+ Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
+ }
}
fn scatter_add(
@@ -1717,8 +1761,11 @@ impl BackendStorage for CpuStorage {
src_l: &Layout,
dim: usize,
) -> Result<Self> {
- let ids = ids.as_slice::<u32>()?;
- ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+ match ids {
+ Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+ Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
+ }
}
fn index_add(
@@ -1730,12 +1777,23 @@ impl BackendStorage for CpuStorage {
src_l: &Layout,
dim: usize,
) -> Result<Self> {
- let ids = ids.as_slice::<u32>()?;
- let ids = match ids_l.contiguous_offsets() {
- Some((a, b)) => &ids[a..b],
- None => Err(Error::RequiresContiguous { op: "index-add" })?,
- };
- IndexAdd { ids, dim }.map(self, l, src, src_l)
+ match ids {
+ Self::U8(ids) => {
+ let ids = match ids_l.contiguous_offsets() {
+ Some((a, b)) => &ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-add" })?,
+ };
+ IndexAdd { ids, dim }.map(self, l, src, src_l)
+ }
+ Self::U32(ids) => {
+ let ids = match ids_l.contiguous_offsets() {
+ Some((a, b)) => &ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-add" })?,
+ };
+ IndexAdd { ids, dim }.map(self, l, src, src_l)
+ }
+ _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
+ }
}
fn matmul(
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 59802c04..c6befbb8 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -119,3 +119,26 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
+
+pub trait IntDType {
+ fn is_true(&self) -> bool;
+ fn as_usize(&self) -> usize;
+}
+
+impl IntDType for u32 {
+ fn is_true(&self) -> bool {
+ *self != 0
+ }
+ fn as_usize(&self) -> usize {
+ *self as usize
+ }
+}
+
+impl IntDType for u8 {
+ fn is_true(&self) -> bool {
+ *self != 0
+ }
+ fn as_usize(&self) -> usize {
+ *self as usize
+ }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 5a35955f..787ea63a 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -61,7 +61,7 @@ mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
-pub use dtype::{DType, WithDType};
+pub use dtype::{DType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 5e6cfdf2..52af5861 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -206,7 +206,6 @@ impl Storage {
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
- // TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a126d634..95ce982a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -270,7 +270,11 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
- // TODO: This is not the expected answer, to be fixed!
+ // PyTorch equivalent:
+ // import torch
+ // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])
+ // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]])
+ // torch.cat([t1.t(), t2.t()], dim=1).t()
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
@@ -282,7 +286,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
- // TODO: This is not the expected answer, to be fixed!
assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 34caf12b..39a09069 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -1,26 +1,20 @@
-// TODO: Use a proper distributed reduction rather than atomicAdd.
-// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
#include "cuda_utils.cuh"
-#include<stdint.h>
-#include<cmath>
+#include <cmath>
+#include <stdint.h>
const int BLOCK_SIZE = 1024;
-// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but
-// also expect a f32 output so that this can be used for normalization e.g. in softmax.
+// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
+// but also expect a f32 output so that this can be used for normalization e.g.
+// in softmax.
// Fast reduce sum kernel, this assumes that the dimensions to loop over are at
-// the end, each block is responsible for populating one value in the output array.
-// There are at most 1024 threads per block.
+// the end, each block is responsible for populating one value in the output
+// array. There are at most 1024 threads per block.
template <typename T>
-__device__ void fast_sum(
- const size_t src_numel,
- const size_t el_to_sum_per_block,
- const size_t num_dims,
- const size_t *info,
- const T *src,
- T *dst
-) {
+__device__ void
+fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
+ const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
@@ -47,21 +41,18 @@ __device__ void fast_sum(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
- if (tid < s) shr[tid] += shr[tid + s];
+ if (tid < s)
+ shr[tid] += shr[tid + s];
}
- if (tid == 0) dst[dst_id] = shr[0];
+ if (tid == 0)
+ dst[dst_id] = shr[0];
}
template <typename T>
-__device__ void fast_max(
- const size_t src_numel,
- const size_t el_to_sum_per_block,
- const size_t num_dims,
- const size_t *info,
- const T *src,
- T *dst
-) {
+__device__ void
+fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
+ const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
@@ -88,21 +79,18 @@ __device__ void fast_max(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
- if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
+ if (tid < s)
+ shr[tid] = maxg(shr[tid], shr[tid + s]);
}
- if (tid == 0) dst[dst_id] = shr[0];
+ if (tid == 0)
+ dst[dst_id] = shr[0];
}
template <typename T>
-__device__ void fast_min(
- const size_t src_numel,
- const size_t el_to_sum_per_block,
- const size_t num_dims,
- const size_t *info,
- const T *src,
- T *dst
-) {
+__device__ void
+fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
+ const size_t num_dims, const size_t *info, const T *src, T *dst) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
@@ -129,83 +117,69 @@ __device__ void fast_min(
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
__syncthreads();
- if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
+ if (tid < s)
+ shr[tid] = ming(shr[tid], shr[tid + s]);
}
- if (tid == 0) dst[dst_id] = shr[0];
+ if (tid == 0)
+ dst[dst_id] = shr[0];
}
-#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
-extern "C" __global__ void MIN_NAME( \
- const size_t src_numel, \
- const size_t el_to_sum_per_block, \
- const size_t num_dims, \
- const size_t *info, \
- const TYPENAME *src, \
- TYPENAME *dst \
-) { \
- fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-extern "C" __global__ void MAX_NAME( \
- const size_t src_numel, \
- const size_t el_to_sum_per_block, \
- const size_t num_dims, \
- const size_t *info, \
- const TYPENAME *src, \
- TYPENAME *dst \
-) { \
- fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-extern "C" __global__ void SUM_NAME( \
- const size_t src_numel, \
- const size_t el_to_sum_per_block, \
- const size_t num_dims, \
- const size_t *info, \
- const TYPENAME *src, \
- TYPENAME *dst \
-) { \
- fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-
-#define SUM_OP(TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
- const size_t numel, \
- const size_t num_dims, \
- const size_t num_sum_dims, \
- const size_t *info, \
- const TYPENAME *inp, \
- TYPENAME *out \
-) { \
- const size_t *dims = info; \
- const size_t *strides = info + num_dims; \
- const size_t *sum_dims_l = info + 2*num_dims; \
- const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \
- if (is_contiguous(num_dims, dims, strides)) { \
- for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
- size_t dst_index = i; \
- for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
- size_t stride = sum_dims_s[nd]; \
- size_t pre = dst_index / stride; \
- size_t post = dst_index % stride; \
- dst_index = (pre / sum_dims_l[nd]) * stride + post; \
- } \
- atomicAdd(out + dst_index, inp[i]); \
- } \
- } \
- else { \
- for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
- unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
- size_t dst_index = i; \
- for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
- size_t stride = sum_dims_s[nd]; \
- size_t pre = dst_index / stride; \
- size_t post = dst_index % stride; \
- dst_index = (pre / sum_dims_l[nd]) * stride + post; \
- } \
- atomicAdd(out + dst_index, inp[strided_i]); \
- } \
- } \
-} \
+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
+ extern "C" __global__ void MIN_NAME( \
+ const size_t src_numel, const size_t el_to_sum_per_block, \
+ const size_t num_dims, const size_t *info, const TYPENAME *src, \
+ TYPENAME *dst) { \
+ fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+ } \
+ extern "C" __global__ void MAX_NAME( \
+ const size_t src_numel, const size_t el_to_sum_per_block, \
+ const size_t num_dims, const size_t *info, const TYPENAME *src, \
+ TYPENAME *dst) { \
+ fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+ } \
+ extern "C" __global__ void SUM_NAME( \
+ const size_t src_numel, const size_t el_to_sum_per_block, \
+ const size_t num_dims, const size_t *info, const TYPENAME *src, \
+ TYPENAME *dst) { \
+ fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+ }
+
+#define SUM_OP(TYPENAME, FN_NAME) \
+ extern "C" __global__ void FN_NAME( \
+ const size_t numel, const size_t num_dims, const size_t num_sum_dims, \
+ const size_t *info, const TYPENAME *inp, TYPENAME *out) { \
+ const size_t *dims = info; \
+ const size_t *strides = info + num_dims; \
+ const size_t *sum_dims_l = info + 2 * num_dims; \
+ const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \
+ if (is_contiguous(num_dims, dims, strides)) { \
+ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
+ i += blockDim.x * gridDim.x) { \
+ size_t dst_index = i; \
+ for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
+ size_t stride = sum_dims_s[nd]; \
+ size_t pre = dst_index / stride; \
+ size_t post = dst_index % stride; \
+ dst_index = (pre / sum_dims_l[nd]) * stride + post; \
+ } \
+ atomicAdd(out + dst_index, inp[i]); \
+ } \
+ } else { \
+ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
+ i += blockDim.x * gridDim.x) { \
+ unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+ size_t dst_index = i; \
+ for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
+ size_t stride = sum_dims_s[nd]; \
+ size_t pre = dst_index / stride; \
+ size_t post = dst_index % stride; \
+ dst_index = (pre / sum_dims_l[nd]) * stride + post; \
+ } \
+ atomicAdd(out + dst_index, inp[strided_i]); \
+ } \
+ } \
+ }
#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)