author     Laurent Mazare <laurent.mazare@gmail.com>    2023-07-08 12:43:56 +0100
committer  GitHub <noreply@github.com>                  2023-07-08 12:43:56 +0100
commit     e676f85f00184e3bb8b878a1de4ed45a2cf883f4 (patch)
tree       5fed6047e24a698d2c8d20f9fb658a8b6ea391aa
parent     33479c5f1b98a6e9f537ea139449bb8dc26fed3e (diff)
Sketch a fast cuda kernel for reduce-sum. (#109)
* Sketch a fast cuda kernel for reduce-sum.
* Sketch the rust support code for the fast sum kernel.
* More work on the fast kernel.
* Add some testing ground.
* A couple fixes for the fast sum kernel.
-rw-r--r--  candle-core/examples/cuda_basics.rs   15
-rw-r--r--  candle-core/src/cuda_backend.rs       53
-rw-r--r--  candle-kernels/src/reduce.cu          67
3 files changed, 134 insertions(+), 1 deletion(-)
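The host-side change in candle-core/src/cuda_backend.rs below maps one CUDA block to each output element: the grid holds `dst_el` blocks, each block reduces `el_to_sum_per_block` source elements, and the thread count per block is capped at 1024 and rounded up to a power of two so the shared-memory tree reduction is well formed. A minimal standalone sketch of that arithmetic is shown here; the `launch_geometry` helper and the example shapes are illustrative only and not part of the patch.

```rust
/// Launch geometry used by the fast reduce-sum sketch: one block per output element.
fn launch_geometry(src_dims: &[usize], sum_dims: &[usize]) -> (usize, usize, usize) {
    let src_el: usize = src_dims.iter().product();
    // Number of output elements: product of the dimensions that are kept.
    let mut dst_el: usize = 1;
    for (dim_idx, &d) in src_dims.iter().enumerate() {
        if !sum_dims.contains(&dim_idx) {
            dst_el *= d;
        }
    }
    let el_to_sum_per_block = src_el / dst_el;
    // Threads per block: capped at 1024 and rounded up to a power of two so the
    // shared-memory tree reduction has a well-defined shape.
    let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
    (dst_el, el_to_sum_per_block, block_dim)
}

fn main() {
    // Summing a (32, 1024) tensor over dim 1: 32 blocks of 1024 threads,
    // each block accumulating 1024 elements into one output value.
    assert_eq!(launch_geometry(&[32, 1024], &[1]), (32, 1024, 1024));
    // Summing the same tensor over dim 0 instead: 1024 blocks of 32 threads.
    assert_eq!(launch_geometry(&[32, 1024], &[0]), (1024, 32, 32));
}
```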
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
new file mode 100644
index 00000000..aeee541a
--- /dev/null
+++ b/candle-core/examples/cuda_basics.rs
@@ -0,0 +1,15 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle::{Device, Tensor};
+
+fn main() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
+    let sum = t.sum(&[0])?;
+    println!("{sum}");
+    let sum = t.sum(&[1])?;
+    println!("{sum}");
+    Ok(())
+}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b1990b8f..543d1280 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -357,6 +357,7 @@ impl Map1 for Affine {
     }
 }
 
+#[allow(dead_code)]
 struct Sum<'a>(&'a [usize]);
 impl<'a> Map1 for Sum<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -393,6 +394,56 @@ impl<'a> Map1 for Sum<'a> {
     }
 }
 
+#[allow(dead_code)]
+struct FastSum<'a>(&'a [usize]);
+impl<'a> Map1 for FastSum<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let src_stride = layout.stride();
+        let src_dims = layout.shape().dims();
+        let src_el: usize = src_dims.iter().product();
+        // Source dims and strides with the sum dims at the end.
+        let mut dims = vec![];
+        let mut stride = vec![];
+        let mut dst_el: usize = 1;
+        for (dim_idx, &d) in src_dims.iter().enumerate() {
+            if !self.0.contains(&dim_idx) {
+                dst_el *= d;
+                dims.push(d);
+                stride.push(src_stride[dim_idx]);
+            }
+        }
+        for &dim_idx in self.0.iter() {
+            dims.push(src_dims[dim_idx]);
+            stride.push(src_stride[dim_idx]);
+        }
+        let el_to_sum_per_block = src_el / dst_el;
+        // The reduction loop requires the shared array to be properly initialized and for
+        // this we want the number of threads to be a power of two.
+        let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
+        let cfg = LaunchConfig {
+            // TODO: Maybe use grid_y if the output is too large?
+            // TODO: Specialized implementation when reducing on no or all dimensions or when
+            // reducing only aggregate a small number of elements together.
+            grid_dim: (dst_el as u32, 1, 1),
+            block_dim: (block_dim as u32, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        let ds = dev.htod_copy([dims.as_slice(), stride.as_slice()].concat())?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
+        let out = dev.alloc_zeros::<T>(dst_el)?;
+        let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 impl<U: crate::op::UnaryOp> Map1 for U {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -726,7 +777,7 @@ impl CudaStorage {
 
     pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device().clone();
-        let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
+        let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index c341fcfb..afe687bf 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -3,6 +3,67 @@
 #include "cuda_utils.cuh"
 #include<stdint.h>
 
+const int BLOCK_SIZE = 1024;
+
+// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but
+// also expect a f32 output so that this can be used for normalization e.g. in softmax.
+
+// Fast reduce sum kernel, this assumes that the dimensions to loop over are at
+// the end, each block is responsible for populating one value in the output array.
+// There are at most 1024 threads per block.
+template <typename T>
+__device__ void fast_sum(
+    const size_t src_numel,
+    const size_t el_to_sum_per_block,
+    const size_t num_dims,
+    const size_t *info,
+    const T *src,
+    T *dst
+) {
+  const size_t *dims = info;
+  const size_t *strides = info + num_dims;
+
+  __shared__ T shr[BLOCK_SIZE];
+  size_t tid = threadIdx.x;
+  size_t dst_id = blockIdx.x;
+
+  shr[tid] = 0.0;
+  // Elements summed in this block range from dst_id * el_to_sum_per_block
+  // to (dst_id + 1) * el_to_sum_per_block.
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  while (idx < stop_idx) {
+    // TODO: Fast version for the contiguous case.
+    size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+    shr[tid] += src[strided_i];
+    idx += blockDim.x;
+  }
+
+  // Parallel reduction, see the slides:
+  // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
+  // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
+  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+    __syncthreads();
+    if (tid < s) shr[tid] += shr[tid + s];
+  }
+
+  if (tid == 0) atomicAdd(dst + dst_id, shr[0]);
+}
+
+#define FAST_SUM_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t src_numel, \
+    const size_t el_to_sum_per_block, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+} \
+
 #define SUM_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
@@ -45,12 +106,18 @@ extern "C" __global__ void FN_NAME( \
 
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
+FAST_SUM_OP(__nv_bfloat16, fast_sum_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 SUM_OP(__half, sum_f16)
+FAST_SUM_OP(__half, fast_sum_f16)
 #endif
 
 SUM_OP(float, sum_f32)
 SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
+
+FAST_SUM_OP(float, fast_sum_f32)
+FAST_SUM_OP(double, fast_sum_f64)
+FAST_SUM_OP(uint32_t, fast_sum_u32)
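To make the indexing handled by the `fast_sum` kernel above concrete, here is a single-threaded Rust sketch of the value each block is expected to produce. It assumes the dims/strides passed through `info` have the summed dimensions moved to the end (as done by `FastSum`) and that `get_strided_index` walks the dimensions from the innermost one outwards; the `strided_index` and `reduce_sum_reference` helpers are illustrative only and not part of the patch.

```rust
// Offset of logical element `idx` in a strided buffer, walking the dimensions
// from the innermost (fastest varying) one outwards.
fn strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&d, &s) in dims.iter().zip(strides.iter()).rev() {
        offset += (idx % d) * s;
        idx /= d;
    }
    offset
}

// What each block computes: `dims`/`strides` describe the source with the summed
// dimensions moved to the end, so output `dst_id` owns the logical index range
// [dst_id * el_to_sum_per_block, (dst_id + 1) * el_to_sum_per_block).
fn reduce_sum_reference(
    src: &[f32],
    dims: &[usize],
    strides: &[usize],
    el_to_sum_per_block: usize,
    dst_el: usize,
) -> Vec<f32> {
    let src_numel: usize = dims.iter().product();
    let mut dst = vec![0f32; dst_el];
    for dst_id in 0..dst_el {
        let start_idx = dst_id * el_to_sum_per_block;
        let stop_idx = usize::min(start_idx + el_to_sum_per_block, src_numel);
        for idx in start_idx..stop_idx {
            dst[dst_id] += src[strided_index(idx, dims, strides)];
        }
    }
    dst
}

fn main() {
    // The 1x4 tensor from cuda_basics.rs summed over dim 1: the kept dim (size 1)
    // comes first, the summed dim (size 4) last, so dims = [1, 4], strides = [4, 1].
    let src = [1f32, 2., 3., 4.2];
    let out = reduce_sum_reference(&src, &[1, 4], &[4, 1], 4, 1);
    assert!((out[0] - 10.2).abs() < 1e-5);
}
```

For that 1x4 example, `el_to_sum_per_block` is 4 and `dst_el` is 1, so on the GPU a single block of four threads accumulates into `shr` before the tree reduction and the final `atomicAdd` into the output.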