author     Michael Feil <63565275+michaelfeil@users.noreply.github.com>   2024-12-31 09:32:22 +0100
committer  GitHub <noreply@github.com>                                    2024-12-31 09:32:22 +0100
commit     71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch)
tree       207e7d050e9c4bd685a563e457b2e9ef59b66f20
parent     d60eba140820326ffc7ec39a8982e91feb462732 (diff)
download   candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.gz
           candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.bz2
           candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
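
The "SoftCap" part of this change ports tanh-based soft-capping of the attention logits (popularized by Gemma 2) into the candle flash-attn kernels; the device-side helper is apply_softcap, moved into kernels/utils.h further down in this diff. For orientation, here is a host-side C++ sketch of the usual cap * tanh(x / cap) formulation. The kernel itself works on registers, uses cutlass::fast_tanh, and folds the scaling into a single softcap multiplier, so the function name and the cap value below are illustrative only, not the crate's API.

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative host-side soft-cap: squashes every logit smoothly into (-cap, cap)
// while staying approximately linear around zero.
void soft_cap(std::vector<float> &logits, float cap) {
    for (float &x : logits)
        x = cap * std::tanh(x / cap);
}

int main() {
    std::vector<float> logits = {-200.f, -5.f, 0.f, 5.f, 200.f};
    soft_cap(logits, 30.f);                           // illustrative cap value
    for (float x : logits) std::printf("%.3f ", x);   // extremes saturate near +/-30
    std::printf("\n");
    return 0;
}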
-rw-r--r--  candle-flash-attn/build.rs | 1
m---------  candle-flash-attn/cutlass | 0
-rw-r--r--  candle-flash-attn/kernels/block_info.h | 8
-rw-r--r--  candle-flash-attn/kernels/flash.h | 13
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu | 2
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_kernel.h | 30
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_launch_template.h | 15
-rw-r--r--  candle-flash-attn/kernels/hardware_info.h | 42
-rw-r--r--  candle-flash-attn/kernels/kernel_traits.h | 30
-rw-r--r--  candle-flash-attn/kernels/utils.h | 18
41 files changed, 139 insertions, 82 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 53fec5de..37247646 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -54,6 +54,7 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
+ println!("cargo:rerun-if-changed=kernels/hardware_info.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>
diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass
-Subproject 7d49e6c7e2f8896c47f586706e67e1fb215529d
+Subproject 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 3a23a1e1..cf60d653 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -18,8 +18,9 @@ struct BlockInfo {
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
- , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
- , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+ , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+ , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
@@ -30,13 +31,14 @@ struct BlockInfo {
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
- return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+ return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int leftpad_k;
const int seqlen_k_cache;
const int actual_seqlen_k;
};
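
The block_info.h change above threads a new per-sequence left-padding amount (leftpad_k) through the cached-K bookkeeping: the padded rows are skipped both in the effective sequence lengths and in the K pointer offset. Below is a standalone host-side sketch of the updated k_offset arithmetic, with a free function and plain int64_t in place of the templated BlockInfo member; the concrete strides and values are made up for illustration.

#include <cstdint>
#include <cstdio>

// Mirrors BlockInfo::k_offset: skip `leftpad` rows at the start of sequence `bidb`'s
// K cache. sum_s_k == -1 selects the fixed-batch layout; otherwise the varlen
// (cu_seqlens) layout is used and the offset is purely row-based.
int64_t k_offset(int64_t batch_stride, int64_t row_stride,
                 int bidb, int sum_s_k, int leftpad) {
    return sum_s_k == -1
        ? bidb * batch_stride + int64_t(leftpad) * row_stride
        : int64_t(uint32_t(sum_s_k + leftpad)) * row_stride;
}

int main() {
    // Fixed-batch layout: 4096 rows per batch entry, 128 elements per row,
    // 7 rows of left padding for batch index 2.
    long long off = k_offset(/*batch_stride=*/4096LL * 128, /*row_stride=*/128,
                             /*bidb=*/2, /*sum_s_k=*/-1, /*leftpad=*/7);
    std::printf("k offset = %lld elements\n", off);  // 2*524288 + 7*128 = 1049472
    return 0;
}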
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 88c2f22a..f21e4d62 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
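
flash.h gains a matching leftpad_k field on Flash_fwd_params. Judging from its use in block_info.h (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]), it is a device array with one int per batch element, and nullptr disables left padding. The helper below is a hypothetical host-side sketch of preparing such an array; the name make_leftpad_k and its std::vector interface are inventions for illustration, not part of the crate.

#include <cuda_runtime.h>
#include <vector>

// Hypothetical setup for Flash_fwd_params::leftpad_k: one int per batch entry,
// or nullptr when nothing is left-padded (the kernels then treat the pad as 0).
int *make_leftpad_k(const std::vector<int> &leftpad_per_seq) {
    if (leftpad_per_seq.empty()) return nullptr;
    int *d_leftpad = nullptr;
    cudaMalloc(&d_leftpad, leftpad_per_seq.size() * sizeof(int));
    cudaMemcpy(d_leftpad, leftpad_per_seq.data(),
               leftpad_per_seq.size() * sizeof(int), cudaMemcpyHostToDevice);
    return d_leftpad;  // assign to params.leftpad_k before launching the kernel
}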
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
index f19049b4..9383c102 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
index cb135741..f03abda4 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
index dfb04b78..c616628c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
index 6df16b2c..4ff6b9fb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
index 230af906..d6d4371b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
index cf1ffad2..5af68ac3 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
index 1fc5ac59..1ef511a6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
index a9796ade..96abfbd8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
index 94792d4d..077d25d0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
index 76d5136b..ea5f265f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
index 9e5b21e0..a4a7bc24 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
index b4019a0b..c30c4a14 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
index a12a5f4a..db69f21c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
index 8690bdb1..9a11724b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
index f01dad09..d02edae0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
index 7ec1e16b..28150ed0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
index 3d816ab6..f84e978c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
index c6c55229..c52f0417 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
index 0149abac..f96f7edc 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
index 9c9a1715..9c7c6b93 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
index 29097ac3..e21d0408 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
index cb52f34f..f377a5b8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
index 7bdadefb..74e4d66a 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index 44b38816..e85db18e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
index 99cd728b..9297e8bb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
index c11096ac..8364b1e7 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
index 2fbcd44e..1c6ed7ef 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
index 7b65a9c9..3c87573b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
index 6fb3cf64..49fae856 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
index e696b2f2..c5af1cf6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
index bb3b744d..b0d6c992 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
index 5f3accc3..c97aa33f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 1bf77f81..b6b26d52 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,6 +4,8 @@
#pragma once
+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {
using namespace cute;
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
- #pragma unroll
- for (int i = 0; i < size(tensor); ++i) {
- tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
- }
-}
-
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
- const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
- const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ const index_t row_offset_knew = bidb * params.knew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
- const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ const index_t row_offset_vnew = bidb * params.vnew_batch_stride
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
- const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+ const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
- apply_softcap(acc_s, params.softcap);
+ flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
constexpr int kBlockN = kNThreads / kBlockM;
using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
using GmemTiledCopyOaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
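
Besides switching to the shared flash::apply_softcap, the split-KV path above folds the same left padding into the rotary-embedding lookups: the cos/sin tables are indexed by absolute position in the K cache, so the starting row moves right by leftpad_k[bidb]. A small arithmetic sketch of the updated row_offset_cossin follows; the tile and dimension values are made up for illustration.

#include <cstdio>

int main() {
    const int kBlockN     = 64;   // K tile height
    const int n_block_max = 5;    // number of K tiles processed for this query block
    const int rotary_dim  = 128;  // rotary embedding dimension
    const int leftpad     = 16;   // params.leftpad_k[bidb]; 0 when leftpad_k == nullptr

    // row_offset_cossin = ((n_block_max - 1) * kBlockN + leftpad) * (rotary_dim / 2)
    long long row_offset_cossin =
        (long long)((n_block_max - 1) * kBlockN + leftpad) * (rotary_dim / 2);
    std::printf("row_offset_cossin = %lld\n", row_offset_cossin);  // (256 + 16) * 64 = 17408
    return 0;
}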
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 9e5449d7..bb581eb3 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -3,11 +3,11 @@
******************************************************************************/
#pragma once
-
-// #include <ATen/cuda/CUDAContext.h>
+// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "error.h"
#include "static_switch.h"
+#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
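
The launch template drops the old cuda_is_sm8x() helper in favor of querying the device through the new hardware_info.h (shown next). The gate itself is unchanged in spirit: prefer the square 64 x 64 causal tile on 8.x parts other than sm_80. A minimal sketch of that predicate over a few representative compute capabilities; the device list in the comment is illustrative, not exhaustive.

#include <cstdio>

// is_sm8x as computed in run_mha_fwd_hdim{96,128,160}: true for sm_86 / sm_89
// (e.g. RTX 30/40-class, A10, L4), false for sm_80 (A100) and sm_90 (H100).
bool is_sm8x(int cc_major, int cc_minor) {
    return cc_major == 8 && cc_minor > 0;
}

int main() {
    const int ccs[][2] = {{8, 0}, {8, 6}, {8, 9}, {9, 0}};
    for (const auto &cc : ccs)
        std::printf("sm_%d%d -> is_sm8x = %d\n", cc[0], cc[1], int(is_sm8x(cc[0], cc[1])));
    return 0;
}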
diff --git a/candle-flash-attn/kernels/hardware_info.h b/candle-flash-attn/kernels/hardware_info.h
new file mode 100644
index 00000000..d5c48d35
--- /dev/null
+++ b/candle-flash-attn/kernels/hardware_info.h
@@ -0,0 +1,42 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <tuple>
+#include <cstdio>
+
+#if !defined(__CUDACC_RTC__)
+#include "cuda_runtime.h"
+#endif
+
+#define CHECK_CUDA(call) \
+ do { \
+ cudaError_t status_ = call; \
+ if (status_ != cudaSuccess) { \
+ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \
+ cudaGetErrorString(status_)); \
+ exit(1); \
+ } \
+ } while (0)
+
+
+inline int get_current_device() {
+ int device;
+ CHECK_CUDA(cudaGetDevice(&device));
+ return device;
+}
+
+inline std::tuple<int, int> get_compute_capability(int device) {
+ int capability_major, capability_minor;
+ CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
+ CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
+ return {capability_major, capability_minor};
+}
+
+inline int get_num_sm(int device) {
+ int multiprocessor_count;
+ CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
+ return multiprocessor_count;
+}
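
The new header is self-contained (no ATen/c10 dependency), so it can be exercised on its own. A quick check that compiles with nvcc (C++17 for the structured binding) and prints what the launch templates will see; CHECK_CUDA aborts with file, line, and the CUDA error string if any runtime call fails.

#include <cstdio>
#include "hardware_info.h"

int main() {
    int device = get_current_device();
    auto [cc_major, cc_minor] = get_compute_capability(device);  // std::tuple<int, int>
    int num_sm = get_num_sm(device);
    std::printf("device %d: sm_%d%d, %d SMs\n", device, cc_major, cc_minor, num_sm);
    return 0;
}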
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 5a7b7491..8c089748 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
- using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
- using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+ using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
+ using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
+ AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
- using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
- using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
- using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
+ AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(tensor); ++i) {
+ tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+ }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(src_tensor); ++i) {
+ dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace flash
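
apply_softcap now lives here in the flash namespace so the forward kernels above (and future backward code) can share it, and calculate_dtanh supplies the matching chain-rule factor: if src_tensor already holds t = tanh(softcap * x) from the forward pass (my reading of the code, not stated in the diff), then (1 - t*t) * softcap is exactly d/dx tanh(softcap * x). A host-side numerical spot check of that identity:

#include <cmath>
#include <cstdio>

int main() {
    const float softcap = 0.5f, x = 1.3f, eps = 1e-3f;

    float t        = std::tanh(softcap * x);        // value apply_softcap leaves in the tensor
    float analytic = (1.f - t * t) * softcap;       // calculate_dtanh's per-element formula
    float numeric  = (std::tanh(softcap * (x + eps)) -
                      std::tanh(softcap * (x - eps))) / (2.f * eps);

    std::printf("analytic = %.6f  numeric = %.6f\n", analytic, numeric);
    return 0;
}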