From 71cd6d55337b1541f602c1afffa6baf6dd75b09c Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Tue, 31 Dec 2024 09:32:22 +0100
Subject: Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
---
 candle-flash-attn/kernels/flash_fwd_kernel.h | 30 ++++++++++++----------------
 1 file changed, 13 insertions(+), 17 deletions(-)

(limited to 'candle-flash-attn/kernels/flash_fwd_kernel.h')

diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 1bf77f81..b6b26d52 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,6 +4,8 @@
 
 #pragma once
 
+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
        }
 
         flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
 
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kBlockN = kNThreads / kBlockM;
     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
-- 
cgit v1.2.3