diff options
author | Michael Feil <63565275+michaelfeil@users.noreply.github.com> | 2024-12-31 09:32:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-31 09:32:22 +0100 |
commit | 71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch) | |
tree | 207e7d050e9c4bd685a563e457b2e9ef59b66f20 /candle-flash-attn/kernels/flash_fwd_kernel.h | |
parent | d60eba140820326ffc7ec39a8982e91feb462732 (diff) | |
download | candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.gz candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.bz2 candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.zip |
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_kernel.h')
-rw-r--r-- | candle-flash-attn/kernels/flash_fwd_kernel.h | 30 |
1 files changed, 13 insertions, 17 deletions
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h index 1bf77f81..b6b26d52 100644 --- a/candle-flash-attn/kernels/flash_fwd_kernel.h +++ b/candle-flash-attn/kernels/flash_fwd_kernel.h @@ -4,6 +4,8 @@ #pragma once +// #include "philox_unpack.cuh" // For at::cuda::philox::unpack + #include <cute/tensor.hpp> #include <cutlass/cutlass.h> @@ -22,14 +24,6 @@ namespace flash { using namespace cute; -template <typename Engine, typename Layout> -__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){ - #pragma unroll - for (int i = 0; i < size(tensor); ++i) { - tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN> @@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } mask.template apply_mask<Is_causal, Is_even_MN>( @@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); @@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), Shape<Int<kBlockN>, Int<kHeadDim / 2>>{}, make_stride(params.rotary_dim / 2, _1{})); @@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. @@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), @@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); // if (cute::thread0()) { print(acc_s); } if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } @@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons smem_thr_copy_Q, smem_thr_copy_K ); if constexpr (Is_softcap){ - apply_softcap(acc_s, params.softcap); + flash::apply_softcap(acc_s, params.softcap); } flash::cp_async_wait<0>(); @@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { constexpr int kBlockN = kNThreads / kBlockM; using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; |