From 71cd6d55337b1541f602c1afffa6baf6dd75b09c Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Tue, 31 Dec 2024 09:32:22 +0100
Subject: Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
---
 candle-flash-attn/kernels/utils.h | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

(limited to 'candle-flash-attn/kernels/utils.h')

diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace flash
-- 
cgit v1.2.3