summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/utils.h')
-rw-r--r--candle-flash-attn/kernels/utils.h18
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(tensor); ++i) {
+ tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+ }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(src_tensor); ++i) {
+ dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace flash