author    | OlivierDehaene <Olivier.dehaene@gmail.com> | 2024-01-05 18:28:55 +0100
committer | GitHub <noreply@github.com>                | 2024-01-05 18:28:55 +0100
commit    | 8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 (patch)
tree      | 289466c9df7a7f21ec1e574cd6cfd7b957998a08 /candle-flash-attn/kernels/softmax.h
parent    | 3a7304cb0dbdf8ceeab8a4f5cf9b8e7ced822e20 (diff)
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels
* fmt
* remove unused kernels
* force f32
* correct stride
Diffstat (limited to 'candle-flash-attn/kernels/softmax.h')
-rw-r--r-- | candle-flash-attn/kernels/softmax.h | 57
1 file changed, 34 insertions, 23 deletions
diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h
index 3e9a7b45..09a93f14 100644
--- a/candle-flash-attn/kernels/softmax.h
+++ b/candle-flash-attn/kernels/softmax.h
@@ -8,8 +8,7 @@
 
 #include <cute/tensor.hpp>
 
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
 
 #include "philox.cuh"
 #include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 }
 
 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+            const int col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
@@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
     }
 }
 
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
-                                         const uint32_t warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
         #pragma unroll
         for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
             #pragma unroll
             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
-                    if (col_idx >= col_idx_limit) {
+                    const int col_idx = col_idx_base + j;
+                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                         tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                     }
                 }
@@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
     }
 }
 
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
+
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
 inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_start, int block_col_start,
+                                     int block_row_stride) {
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
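
In the updated `apply_mask`, the column index is now built in two steps (`col_idx_base` per `nj`, then `+ j`), but aside from the new `col_idx_offset_` parameter the mapping is unchanged: within an MMA accumulator tile each group of four lanes covers eight consecutive key columns, each lane owning two adjacent ones, and every repetition along N (`nj`) adds another eight. A minimal standalone sketch of that index arithmetic (illustrative only; `n_tiles` stands in for `size<1, 1>(tensor)` and is not a name from the kernel):

```cpp
#include <cstdio>

// Illustrative sketch (not part of the commit) of the lane -> key-column
// mapping assumed by apply_mask: (lane_id % 4) * 2 picks the lane's base
// column inside an 8-column MMA tile, nj * 8 walks the tile repetitions
// along N, and j in {0, 1} covers the two adjacent columns each lane owns.
int main() {
    const int col_idx_offset_ = 0;  // the new, defaulted parameter
    const int n_tiles = 2;          // stands in for size<1, 1>(tensor)
    for (int lane_id = 0; lane_id < 4; ++lane_id) {
        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
        for (int nj = 0; nj < n_tiles; ++nj) {
            const int col_idx_base = col_idx_offset + nj * 8;
            for (int j = 0; j < 2; ++j) {
                std::printf("lane %d, nj %d, j %d -> col %d\n",
                            lane_id, nj, j, col_idx_base + j);
            }
        }
    }
    return 0;
}
```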
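
The new `apply_mask_local` encodes sliding-window (local) attention: a key column `col_idx` is kept only if it lies inside `[row_idx + max_seqlen_k - max_seqlen_q - window_size_left, row_idx + max_seqlen_k - max_seqlen_q + window_size_right]`, clamped to `[0, max_seqlen_k)`; the `max_seqlen_k - max_seqlen_q` shift right-aligns the query rows with the key columns. The host-side sketch below (hypothetical helper names, not part of the commit) restates that per-element predicate and shows how the new `apply_mask_causal` wrapper reduces to the local mask with no left bound and `window_size_right = 0`.

```cpp
#include <algorithm>
#include <cstdio>

// Hypothetical helper restating the per-element predicate of apply_mask_local:
// returns true when the score at (row_idx, col_idx) should be set to -INFINITY.
bool is_masked_local(int row_idx, int col_idx,
                     int max_seqlen_q, int max_seqlen_k,
                     int window_size_left, int window_size_right,
                     bool has_ws_left = true) {
    // The (max_seqlen_k - max_seqlen_q) shift right-aligns query rows with key columns.
    const int col_idx_limit_left =
        std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
    const int col_idx_limit_right =
        std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
    return col_idx >= col_idx_limit_right || (has_ws_left && col_idx < col_idx_limit_left);
}

// Causal masking, as in the new apply_mask_causal wrapper: no left bound
// (has_ws_left = false, window_size_left = -1) and window_size_right = 0.
bool is_masked_causal(int row_idx, int col_idx, int max_seqlen_q, int max_seqlen_k) {
    return is_masked_local(row_idx, col_idx, max_seqlen_q, max_seqlen_k,
                           /*window_size_left=*/-1, /*window_size_right=*/0,
                           /*has_ws_left=*/false);
}

int main() {
    // With max_seqlen_q == max_seqlen_k == 8, row 3 may attend to columns 0..3 only.
    for (int col = 0; col < 8; ++col) {
        std::printf("row 3, col %d: %s\n", col,
                    is_masked_causal(3, col, 8, 8) ? "masked" : "kept");
    }
    return 0;
}
```

The `HasWSLeft` template flag exists because the kernels pass `window_size_left = -1` to mean "unbounded": compiling the left-hand comparison out entirely prevents `-1` from being treated as a real one-column window and keeps that side of the check out of the inner loop.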