diff options
Diffstat (limited to 'candle-flash-attn/kernels/utils.h')
-rw-r--r-- | candle-flash-attn/kernels/utils.h | 18 |
1 files changed, 18 insertions, 0 deletions
// Provenance: diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
//             index 708aeddf..b7408ec4 100644, hunk @@ -390,4 +390,22 @@
//             (hunk context: the additions follow copy_w_min_idx in the same header).

////////////////////////////////////////////////////////////////////////////////////////////////////

// Overwrites every element x of `tensor`, in place, with
// cutlass::fast_tanh(x * softcap).
//
//   tensor  - tensor whose elements are read and then rewritten.
//   softcap - factor each element is multiplied by before the tanh is taken.
//
// NOTE(review): the result is not rescaled after the tanh here; presumably the
// caller handles any remaining scaling -- confirm what value is passed as
// `softcap`.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
    #pragma unroll
    for (int idx = 0; idx < size(tensor); ++idx) {
        // Scale first, then squash through the fast tanh approximation.
        const auto scaled = tensor(idx) * softcap;
        tensor(idx) = cutlass::fast_tanh(scaled);
    }
}

// For every index i in [0, size(src_tensor)), stores
// (1 - src_tensor(i)^2) * softcap into dst_tensor(i).
//
//   src_tensor - values that are squared; only read by this function.
//   dst_tensor - receives the results; it is indexed over the same range as
//                src_tensor, so it is assumed to hold at least
//                size(src_tensor) elements.
//   softcap    - factor applied to each (1 - s^2) term.
//
// NOTE(review): going by the name, src_tensor presumably holds tanh outputs,
// which would make (1 - s^2) the tanh derivative -- confirm against the caller.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
    #pragma unroll
    for (int idx = 0; idx < size(src_tensor); ++idx) {
        // Read the source element once and reuse it for the square.
        const auto s = src_tensor(idx);
        dst_tensor(idx) = (1.f - (s * s)) * softcap;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash