summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/utils.h
diff options
context:
space:
mode:
authorMichael Feil <63565275+michaelfeil@users.noreply.github.com>2024-12-31 09:32:22 +0100
committerGitHub <noreply@github.com>2024-12-31 09:32:22 +0100
commit71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch)
tree207e7d050e9c4bd685a563e457b2e9ef59b66f20 /candle-flash-attn/kernels/utils.h
parentd60eba140820326ffc7ec39a8982e91feb462732 (diff)
downloadcandle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.gz
candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.bz2
candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace
Diffstat (limited to 'candle-flash-attn/kernels/utils.h')
-rw-r--r--candle-flash-attn/kernels/utils.h18
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(tensor); ++i) {
+ tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+ }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(src_tensor); ++i) {
+ dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace flash