author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-23 17:00:00 +0200
committer | GitHub <noreply@github.com> | 2023-07-23 16:00:00 +0100
commit    | 23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0 (patch)
tree      | 04404a97a114126cd5faaaeb97a486f9cdb7b920 /candle-kernels/src
parent    | e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca (diff)
download  | candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.gz candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.tar.bz2 candle-23827c49cd6c983ba0dc36c1cbc9cc397f43b2c0.zip
Cleanup some todos. (#226)
* Cleanup some todos.
* Fix more todo.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the intdtype trait for more ops.
* Remove a todo.
* Remove a todo.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/reduce.cu | 192
1 file changed, 83 insertions, 109 deletions
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 34caf12b..39a09069 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -1,26 +1,20 @@
-// TODO: Use a proper distributed reduction rather than atomicAdd.
-// https://people.maths.ox.ac.uk/gilesm/cuda/prac4/reduction.pdf
 #include "cuda_utils.cuh"
-#include<stdint.h>
-#include<cmath>
+#include <cmath>
+#include <stdint.h>
 
 const int BLOCK_SIZE = 1024;
 
-// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 but
-// also expect a f32 output so that this can be used for normalization e.g. in softmax.
+// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
+// but also expect a f32 output so that this can be used for normalization e.g.
+// in softmax.
 
 // Fast reduce sum kernel, this assumes that the dimensions to loop over are at
-// the end, each block is responsible for populating one value in the output array.
-// There are at most 1024 threads per block.
+// the end, each block is responsible for populating one value in the output
+// array. There are at most 1024 threads per block.
 template <typename T>
-__device__ void fast_sum(
-    const size_t src_numel,
-    const size_t el_to_sum_per_block,
-    const size_t num_dims,
-    const size_t *info,
-    const T *src,
-    T *dst
-) {
+__device__ void
+fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
+         const size_t num_dims, const size_t *info, const T *src, T *dst) {
   const size_t *dims = info;
   const size_t *strides = info + num_dims;
 
@@ -47,21 +41,18 @@ __device__ void fast_sum(
   // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
   for (int s = blockDim.x / 2; s > 0; s >>= 1) {
     __syncthreads();
-    if (tid < s) shr[tid] += shr[tid + s];
+    if (tid < s)
+      shr[tid] += shr[tid + s];
   }
 
-  if (tid == 0) dst[dst_id] = shr[0];
+  if (tid == 0)
+    dst[dst_id] = shr[0];
 }
 
 template <typename T>
-__device__ void fast_max(
-    const size_t src_numel,
-    const size_t el_to_sum_per_block,
-    const size_t num_dims,
-    const size_t *info,
-    const T *src,
-    T *dst
-) {
+__device__ void
+fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
+         const size_t num_dims, const size_t *info, const T *src, T *dst) {
   const size_t *dims = info;
   const size_t *strides = info + num_dims;
 
@@ -88,21 +79,18 @@ __device__ void fast_max(
   // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
   for (int s = blockDim.x / 2; s > 0; s >>= 1) {
     __syncthreads();
-    if (tid < s) shr[tid] = maxg(shr[tid], shr[tid + s]);
+    if (tid < s)
+      shr[tid] = maxg(shr[tid], shr[tid + s]);
   }
 
-  if (tid == 0) dst[dst_id] = shr[0];
+  if (tid == 0)
+    dst[dst_id] = shr[0];
 }
 
 template <typename T>
-__device__ void fast_min(
-    const size_t src_numel,
-    const size_t el_to_sum_per_block,
-    const size_t num_dims,
-    const size_t *info,
-    const T *src,
-    T *dst
-) {
+__device__ void
+fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
+         const size_t num_dims, const size_t *info, const T *src, T *dst) {
  const size_t *dims = info;
  const size_t *strides = info + num_dims;
 
@@ -129,83 +117,69 @@ __device__ void fast_min(
   // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
   for (int s = blockDim.x / 2; s > 0; s >>= 1) {
     __syncthreads();
-    if (tid < s) shr[tid] = ming(shr[tid], shr[tid + s]);
+    if (tid < s)
+      shr[tid] = ming(shr[tid], shr[tid + s]);
   }
 
-  if (tid == 0) dst[dst_id] = shr[0];
+  if (tid == 0)
+    dst[dst_id] = shr[0];
 }
 
-#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
-extern "C" __global__ void MIN_NAME( \
-    const size_t src_numel, \
-    const size_t el_to_sum_per_block, \
-    const size_t num_dims, \
-    const size_t *info, \
-    const TYPENAME *src, \
-    TYPENAME *dst \
-) { \
-  fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-extern "C" __global__ void MAX_NAME( \
-    const size_t src_numel, \
-    const size_t el_to_sum_per_block, \
-    const size_t num_dims, \
-    const size_t *info, \
-    const TYPENAME *src, \
-    TYPENAME *dst \
-) { \
-  fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-extern "C" __global__ void SUM_NAME( \
-    const size_t src_numel, \
-    const size_t el_to_sum_per_block, \
-    const size_t num_dims, \
-    const size_t *info, \
-    const TYPENAME *src, \
-    TYPENAME *dst \
-) { \
-  fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
-} \
-
-#define SUM_OP(TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const size_t numel, \
-    const size_t num_dims, \
-    const size_t num_sum_dims, \
-    const size_t *info, \
-    const TYPENAME *inp, \
-    TYPENAME *out \
-) { \
-    const size_t *dims = info; \
-    const size_t *strides = info + num_dims; \
-    const size_t *sum_dims_l = info + 2*num_dims; \
-    const size_t *sum_dims_s = info + 2*num_dims + num_sum_dims; \
-    if (is_contiguous(num_dims, dims, strides)) { \
-      for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-          size_t dst_index = i; \
-          for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
-              size_t stride = sum_dims_s[nd]; \
-              size_t pre = dst_index / stride; \
-              size_t post = dst_index % stride; \
-              dst_index = (pre / sum_dims_l[nd]) * stride + post; \
-          } \
-          atomicAdd(out + dst_index, inp[i]); \
-      } \
-    } \
-    else { \
-      for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
-          unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
-          size_t dst_index = i; \
-          for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
-              size_t stride = sum_dims_s[nd]; \
-              size_t pre = dst_index / stride; \
-              size_t post = dst_index % stride; \
-              dst_index = (pre / sum_dims_l[nd]) * stride + post; \
-          } \
-          atomicAdd(out + dst_index, inp[strided_i]); \
-      } \
-    } \
-} \
+#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, SUM_NAME) \
+  extern "C" __global__ void MIN_NAME( \
+      const size_t src_numel, const size_t el_to_sum_per_block, \
+      const size_t num_dims, const size_t *info, const TYPENAME *src, \
+      TYPENAME *dst) { \
+    fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+  } \
+  extern "C" __global__ void MAX_NAME( \
+      const size_t src_numel, const size_t el_to_sum_per_block, \
+      const size_t num_dims, const size_t *info, const TYPENAME *src, \
+      TYPENAME *dst) { \
+    fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+  } \
+  extern "C" __global__ void SUM_NAME( \
+      const size_t src_numel, const size_t el_to_sum_per_block, \
+      const size_t num_dims, const size_t *info, const TYPENAME *src, \
+      TYPENAME *dst) { \
+    fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \
+  }
+
+#define SUM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const size_t numel, const size_t num_dims, const size_t num_sum_dims, \
+      const size_t *info, const TYPENAME *inp, TYPENAME *out) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    const size_t *sum_dims_l = info + 2 * num_dims; \
+    const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+      for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
+           i += blockDim.x * gridDim.x) { \
+        size_t dst_index = i; \
+        for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
+          size_t stride = sum_dims_s[nd]; \
+          size_t pre = dst_index / stride; \
+          size_t post = dst_index % stride; \
+          dst_index = (pre / sum_dims_l[nd]) * stride + post; \
+        } \
+        atomicAdd(out + dst_index, inp[i]); \
+      } \
+    } else { \
+      for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \
+           i += blockDim.x * gridDim.x) { \
+        unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+        size_t dst_index = i; \
+        for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \
+          size_t stride = sum_dims_s[nd]; \
+          size_t pre = dst_index / stride; \
+          size_t post = dst_index % stride; \
+          dst_index = (pre / sum_dims_l[nd]) * stride + post; \
+        } \
+        atomicAdd(out + dst_index, inp[strided_i]); \
+      } \
+    } \
+  }
 
 #if __CUDA_ARCH__ >= 800
 SUM_OP(__nv_bfloat16, sum_bf16)
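
The three fast_* device functions in the diff share one reduction shape: each block accumulates its assigned slice of the input into shared memory and then folds it with a tree reduction, so only thread 0 writes the block's output element, and the FAST_OP macro just stamps out extern "C" entry points per dtype that forward to these helpers. The standalone sketch below is not part of the commit; the block_sum kernel, the all-ones input, and the launch sizes are invented for illustration, but the reduction loop is the same one the diff reformats.

// Editor's sketch (not from the commit): the block-level tree reduction that
// fast_sum/fast_max/fast_min rely on, reduced to a self-contained kernel.
// Each block owns one chunk of `el_per_block` elements and writes one output.
#include <cstdio>
#include <cuda_runtime.h>

const int BLOCK_SIZE = 1024;

__global__ void block_sum(const float *src, float *dst, size_t src_numel,
                          size_t el_per_block) {
  __shared__ float shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;
  size_t start = blockIdx.x * el_per_block;
  size_t stop = start + el_per_block;
  if (stop > src_numel)
    stop = src_numel;

  // Each thread accumulates a strided slice of the block's chunk.
  float acc = 0.f;
  for (size_t i = start + tid; i < stop; i += blockDim.x)
    acc += src[i];
  shr[tid] = acc;

  // Tree reduction in shared memory: halve the number of active threads each
  // step, synchronizing before every halving (same loop as in the diff).
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    __syncthreads();
    if (tid < s)
      shr[tid] += shr[tid + s];
  }
  if (tid == 0)
    dst[blockIdx.x] = shr[0];
}

int main() {
  const size_t numel = 1 << 20, el_per_block = 1 << 10;
  const size_t num_blocks = numel / el_per_block;
  float *src, *dst;
  cudaMallocManaged(&src, numel * sizeof(float));
  cudaMallocManaged(&dst, num_blocks * sizeof(float));
  for (size_t i = 0; i < numel; ++i)
    src[i] = 1.f;

  block_sum<<<num_blocks, BLOCK_SIZE>>>(src, dst, numel, el_per_block);
  cudaDeviceSynchronize();
  printf("dst[0] = %.0f (expected %zu)\n", dst[0], el_per_block);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

Compiled with nvcc and run, this should print 1024 for dst[0], i.e. one partial sum per block.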
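
The reworked SUM_OP keeps the earlier output-index arithmetic and only adds a fast path: when the layout is contiguous it reads inp[i] directly, otherwise it reads inp[get_strided_index(i, ...)], while the destination index is derived from the logical index i in both branches before the atomicAdd. The folding loop treats sum_dims_s[nd] as the stride of summed dimension nd in a contiguous layout and sum_dims_l[nd] as its length, and divides that coordinate out of the index. The host-side sketch below is not from the commit (fold_index and the 2x3 example are illustrative assumptions); it traces the arithmetic for a contiguous (2, 3) tensor summed over dim 0.

// Editor's sketch (not from the commit): the dst_index folding used inside
// SUM_OP, pulled out of the macro into a plain function for illustration.
#include <cstdio>

// For summed dimension nd, sum_dims_s[nd] is its (contiguous) stride and
// sum_dims_l[nd] its length; the split around the stride isolates that
// dimension's coordinate, and dividing by the length collapses it to 0.
size_t fold_index(size_t idx, size_t num_sum_dims, const size_t *sum_dims_l,
                  const size_t *sum_dims_s) {
  size_t dst_index = idx;
  for (size_t nd = 0; nd < num_sum_dims; ++nd) {
    size_t stride = sum_dims_s[nd];
    size_t pre = dst_index / stride;
    size_t post = dst_index % stride;
    dst_index = (pre / sum_dims_l[nd]) * stride + post;
  }
  return dst_index;
}

int main() {
  // Contiguous (2, 3) tensor summed over dim 0: that dim has length 2 and
  // stride 3, so input indices 0..5 fold onto output indices 0..2.
  const size_t sum_dims_l[] = {2};
  const size_t sum_dims_s[] = {3};
  for (size_t i = 0; i < 6; ++i)
    printf("inp[%zu] -> out[%zu]\n", i,
           fold_index(i, 1, sum_dims_l, sum_dims_s));
  return 0;
}

All six input positions land on three output slots, matching a per-column sum; in the kernel each thread then atomicAdds its element into out[dst_index].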