summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/flash_fwd_kernel.h
diff options
context:
space:
mode:
authorMichael Feil <63565275+michaelfeil@users.noreply.github.com>2024-12-31 09:32:22 +0100
committerGitHub <noreply@github.com>2024-12-31 09:32:22 +0100
commit71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch)
tree207e7d050e9c4bd685a563e457b2e9ef59b66f20 /candle-flash-attn/kernels/flash_fwd_kernel.h
parentd60eba140820326ffc7ec39a8982e91feb462732 (diff)
downloadcandle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.gz
candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.bz2
candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace
Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_kernel.h')
-rw-r--r--candle-flash-attn/kernels/flash_fwd_kernel.h30
1 files changed, 13 insertions, 17 deletions
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 1bf77f81..b6b26d52 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,6 +4,8 @@
#pragma once
+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {
using namespace cute;
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
- #pragma unroll
- for (int i = 0; i < size(tensor); ++i) {
- tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
- }
-}
-
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
- const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
- const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ const index_t row_offset_knew = bidb * params.knew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
- const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ const index_t row_offset_vnew = bidb * params.vnew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
- const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+ const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;