path: root/candle-flash-attn/kernels/softmax.h
author    OlivierDehaene <Olivier.dehaene@gmail.com>    2024-01-05 18:28:55 +0100
committer GitHub <noreply@github.com>                   2024-01-05 18:28:55 +0100
commit    8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 (patch)
tree      289466c9df7a7f21ec1e574cd6cfd7b957998a08 /candle-flash-attn/kernels/softmax.h
parent    3a7304cb0dbdf8ceeab8a4f5cf9b8e7ced822e20 (diff)
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels
* fmt
* remove unused kernels
* force f32
* correct stride
Diffstat (limited to 'candle-flash-attn/kernels/softmax.h')
-rw-r--r-- candle-flash-attn/kernels/softmax.h | 57
1 file changed, 34 insertions(+), 23 deletions(-)
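
Aside: the main functional change below generalizes the purely causal mask (apply_mask_causal) into a windowed "local" mask (apply_mask_local); causal masking becomes the special case with window_size_right = 0 and no left limit (HasWSLeft=false). The following is a minimal host-side C++ sketch of the per-row visible-column interval the new kernel code computes; the function and variable names are illustrative and are not part of the diff:

// Hypothetical standalone sketch (not from the diff) of the visible-column
// interval [left, right) that apply_mask_local computes for one query row.
#include <algorithm>
#include <cstdio>

void mask_limits(int row_idx, int max_seqlen_q, int max_seqlen_k,
                 int window_size_left, int window_size_right,
                 int &left, int &right) {
    // Rows are shifted by max_seqlen_k - max_seqlen_q so that the last query
    // row attends up to the last key, matching the diff's limit expressions.
    const int shift = max_seqlen_k - max_seqlen_q;
    left  = std::max(0, row_idx + shift - window_size_left);
    right = std::min(max_seqlen_k, row_idx + 1 + shift + window_size_right);
}

int main() {
    int left, right;
    // Causal masking: unbounded left window, window_size_right = 0.
    // (In the kernel the left check is skipped entirely via HasWSLeft=false.)
    mask_limits(/*row_idx=*/5, /*max_seqlen_q=*/8, /*max_seqlen_k=*/8,
                /*window_size_left=*/8, /*window_size_right=*/0, left, right);
    printf("row 5 sees columns [%d, %d)\n", left, right);  // prints [0, 6)
    return 0;
}

Note that in the diff, apply_mask_causal forwards window_size_left = -1; because HasWSLeft is false there, that value is never read, so it merely documents "infinite left window".
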
diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h
index 3e9a7b45..09a93f14 100644
--- a/candle-flash-attn/kernels/softmax.h
+++ b/candle-flash-attn/kernels/softmax.h
@@ -8,8 +8,7 @@
#include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
#include "philox.cuh"
#include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+ const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const uint32_t lane_id = threadIdx.x % 32;
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+ const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
@@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
}
}
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
- const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
- const uint32_t warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride,
+ const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const uint32_t lane_id = threadIdx.x % 32;
- // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
- const uint32_t row_idx_offset = row_idx_offset_;
- const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
- const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
- const uint32_t row_idx = row_idx_base + i * 8;
- const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+ const int row_idx = row_idx_base + i * 8;
+ const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+ const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const uint32_t col_idx_base = col_idx_offset + nj * 8;
+ const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const uint32_t col_idx = col_idx_base + j;
- if (col_idx >= col_idx_limit) {
+ const int col_idx = col_idx_base + j;
+ if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
@@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
}
}
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride) {
+ // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+ apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+ max_seqlen_q, warp_row_stride, -1, 0);
+}
+
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
- const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+ const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
- const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+ const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
- uint32_t block_row_start, uint32_t block_col_start,
- uint32_t block_row_stride) {
+ int block_row_start, int block_col_start,
+ int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {