From 71cd6d55337b1541f602c1afffa6baf6dd75b09c Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Tue, 31 Dec 2024 09:32:22 +0100
Subject: Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
---
 candle-flash-attn/build.rs                         |  1 +
 candle-flash-attn/cutlass                          |  2 +-
 candle-flash-attn/kernels/block_info.h             |  8 +++--
 candle-flash-attn/kernels/flash.h                  | 13 +++----
 .../kernels/flash_fwd_hdim128_bf16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim128_bf16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim128_fp16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim128_fp16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim160_bf16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim160_bf16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim160_fp16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim160_fp16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim192_bf16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim192_bf16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim192_fp16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim192_fp16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim224_bf16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim224_bf16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim224_fp16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim224_fp16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim256_bf16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim256_bf16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim256_fp16_causal_sm80.cu  |  2 +-
 .../kernels/flash_fwd_hdim256_fp16_sm80.cu         |  2 +-
 .../kernels/flash_fwd_hdim32_bf16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim32_bf16_sm80.cu          |  2 +-
 .../kernels/flash_fwd_hdim32_fp16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim32_fp16_sm80.cu          |  2 +-
 .../kernels/flash_fwd_hdim64_bf16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim64_bf16_sm80.cu          |  2 +-
 .../kernels/flash_fwd_hdim64_fp16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim64_fp16_sm80.cu          |  2 +-
 .../kernels/flash_fwd_hdim96_bf16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim96_bf16_sm80.cu          |  2 +-
 .../kernels/flash_fwd_hdim96_fp16_causal_sm80.cu   |  2 +-
 .../kernels/flash_fwd_hdim96_fp16_sm80.cu          |  2 +-
 candle-flash-attn/kernels/flash_fwd_kernel.h       | 30 +++++++---------
 .../kernels/flash_fwd_launch_template.h            | 15 ++++----
 candle-flash-attn/kernels/hardware_info.h          | 42 ++++++++++++++++++++++
 candle-flash-attn/kernels/kernel_traits.h          | 30 ++++++++--------
 candle-flash-attn/kernels/utils.h                  | 18 ++++++++++
 41 files changed, 140 insertions(+), 83 deletions(-)
 create mode 100644 candle-flash-attn/kernels/hardware_info.h

diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 53fec5de..37247646 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -54,6 +54,7 @@ fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
     println!("cargo:rerun-if-changed=kernels/block_info.h");
     println!("cargo:rerun-if-changed=kernels/static_switch.h");
+    println!("cargo:rerun-if-changed=kernels/hardware_info.h");
     let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
     let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
         Err(_) =>
diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass
index 7d49e6c7..4c42f73f 160000
--- a/candle-flash-attn/cutlass
+++ b/candle-flash-attn/cutlass
@@ -1 +1 @@
-Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 3a23a1e1..cf60d653 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 
@@ -30,13 +31,14 @@ struct BlockInfo {
 
     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }
 
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
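
For reference, the new k_offset arithmetic above can be read as a plain function: without a cu_seqlens_k table (sum_s_k == -1), K for batch bidb now starts leftpad_k rows into its slab, and in the varlen case the packed start offset sum_s_k is shifted by the same amount. A standalone sketch of that arithmetic, with the index type simplified to long long (not the device code itself):

#include <cstdint>

// Simplified restatement of BlockInfo::k_offset with left padding.
// batch_stride / row_stride are element strides of the K tensor;
// sum_s_k is the varlen prefix offset, or -1 when K is laid out per batch.
long long k_offset(long long batch_stride, long long row_stride,
                   int bidb, int sum_s_k, int leftpad_k) {
    return sum_s_k == -1
        ? bidb * batch_stride + (long long)leftpad_k * row_stride
        : (long long)(uint32_t)(sum_s_k + leftpad_k) * row_stride;
}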
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 88c2f22a..f21e4d62 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,13 +7,7 @@
 #include <cuda.h>
 #include <vector>
 
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-// 
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
 
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
 
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
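
The comment above describes cu_seqlens_q / cu_seqlens_k as arrays of length b+1 holding the starting offset of each packed sequence, with leftpad_k (when non-null) giving a per-batch left padding for K. A hedged host-side sketch of how such tables could be assembled; the helper name and the example values are illustrative, not part of the candle-flash-attn API:

#include <vector>

// Prefix-sum offsets from per-sequence lengths, as described for cu_seqlens_q/k.
std::vector<int> make_cu_seqlens(const std::vector<int>& seqlens) {
    std::vector<int> cu(seqlens.size() + 1, 0);
    for (size_t i = 0; i < seqlens.size(); ++i) cu[i + 1] = cu[i] + seqlens[i];
    return cu;
}

// Example: K lengths {7, 12, 3} give cu_seqlens_k = {0, 7, 19, 22};
// leftpad_k might then be {0, 2, 0} if only the second sequence is left-padded.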
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
index f19049b4..9383c102 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
index cb135741..f03abda4 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
index dfb04b78..c616628c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
index 6df16b2c..4ff6b9fb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
index 230af906..d6d4371b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
index cf1ffad2..5af68ac3 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
index 1fc5ac59..1ef511a6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
index a9796ade..96abfbd8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
index 94792d4d..077d25d0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
index 76d5136b..ea5f265f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
index 9e5b21e0..a4a7bc24 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
index b4019a0b..c30c4a14 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
index a12a5f4a..db69f21c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
index 8690bdb1..9a11724b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
index f01dad09..d02edae0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
index 7ec1e16b..28150ed0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
index 3d816ab6..f84e978c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
index c6c55229..c52f0417 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
index 0149abac..f96f7edc 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
index 9c9a1715..9c7c6b93 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
index 29097ac3..e21d0408 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
index cb52f34f..f377a5b8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
index 7bdadefb..74e4d66a 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index 44b38816..e85db18e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
index 99cd728b..9297e8bb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
index c11096ac..8364b1e7 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
index 2fbcd44e..1c6ed7ef 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
index 7b65a9c9..3c87573b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
index 6fb3cf64..49fae856 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
index e696b2f2..c5af1cf6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
index bb3b744d..b0d6c992 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
index 5f3accc3..c97aa33f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 1bf77f81..b6b26d52 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,6 +4,8 @@
 
 #pragma once
 
+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
 
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kBlockN = kNThreads / kBlockM;
     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
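
The two rotary-embedding hunks above shift the cos/sin row offset by params.leftpad_k[bidb] so that a left-padded K cache looks up rotary angles at its true (unpadded) positions. A minimal host-side sketch of the same index arithmetic; the tile sizes and padding are made-up example values, not taken from a real launch:

#include <cstdio>

int main() {
    const int kBlockN     = 64;  // K tile size (example value)
    const int n_block_max = 3;   // number of K tiles for this query block (example)
    const int leftpad_k   = 5;   // params.leftpad_k[bidb]; 0 when there is no left padding
    const int rotary_dim  = 64;

    // Mirrors: ((n_block_max - 1) * kBlockN + leftpad_k) * (rotary_dim / 2)
    const long long row_offset_cossin =
        (long long)((n_block_max - 1) * kBlockN + leftpad_k) * (rotary_dim / 2);
    std::printf("row_offset_cossin = %lld\n", row_offset_cossin);
    return 0;
}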
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 9e5449d7..bb581eb3 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -3,11 +3,11 @@
  ******************************************************************************/
 
 #pragma once
-
-// #include <ATen/cuda/CUDAContext.h>
+// #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 
 #include "error.h"
 #include "static_switch.h"
+#include "hardware_info.h"
 #include "flash.h"
 #include "flash_fwd_kernel.h"
 
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                             // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                             // If Is_local, set Is_causal to false
-                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                             // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         if (is_sm8x) {
@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For A100, H100, 128 x 32 is the fastest.
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
diff --git a/candle-flash-attn/kernels/hardware_info.h b/candle-flash-attn/kernels/hardware_info.h
new file mode 100644
index 00000000..d5c48d35
--- /dev/null
+++ b/candle-flash-attn/kernels/hardware_info.h
@@ -0,0 +1,42 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <tuple>
+#include <cstdio>
+
+#if !defined(__CUDACC_RTC__)
+#include "cuda_runtime.h"
+#endif
+
+#define CHECK_CUDA(call)                                                       \
+  do {                                                                         \
+    cudaError_t status_ = call;                                                \
+    if (status_ != cudaSuccess) {                                              \
+      fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,          \
+              cudaGetErrorString(status_));                                    \
+      exit(1);                                                                 \
+    }                                                                          \
+  } while (0)
+
+
+inline int get_current_device() {
+    int device;
+    CHECK_CUDA(cudaGetDevice(&device));
+    return device;
+}
+
+inline std::tuple<int, int> get_compute_capability(int device) {
+    int capability_major, capability_minor;
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
+    return {capability_major, capability_minor};
+}
+
+inline int get_num_sm(int device) {
+    int multiprocessor_count;
+    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
+    return multiprocessor_count;
+}
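
The forward-launch templates below use these helpers to branch on the GPU generation instead of calling cuda_is_sm8x(). A minimal usage sketch (links against the CUDA runtime; it mirrors the is_sm8x test added in flash_fwd_launch_template.h):

#include <cstdio>
#include "hardware_info.h"

int main() {
    const int device = get_current_device();
    auto [cc_major, cc_minor] = get_compute_capability(device);
    const bool is_sm8x = cc_major == 8 && cc_minor > 0;  // sm86 / sm89, but not sm80
    std::printf("device %d: sm%d%d, %d SMs, is_sm8x=%d\n",
                device, cc_major, cc_minor, get_num_sm(device), int(is_sm8x));
    return 0;
}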
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 5a7b7491..8c089748 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutO = decltype(tile_to_shape(
         SmemLayoutAtomO{},
         Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
-    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+    using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
 
     static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
     static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
 
@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
                Stride< _16, _1>>
     >;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
     using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
     using GmemTiledCopyRotcossinCont = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
 };
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
         composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
     using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
 
-    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
 
     using SmemLayoutQdOtransposed = decltype(
         composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdKV = decltype(tile_to_shape(
         SmemLayoutAtomdKV{},
         make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
 
     using SmemLayoutAtomdQ = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdQ = decltype(tile_to_shape(
         SmemLayoutAtomdQ{},
         make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
 
     // Double buffer for sQ
     static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopydO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemTiledCopydKV = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemTiledCopydQ = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemLayoutAtomdQaccum = std::conditional_t<
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
                Stride< _16, _1>>
     >;
     using GmemTiledCopydQaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomdQaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
 
     using GmemTiledCopydQaccumAtomicAdd = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         Layout<Shape <_8, _32>,  // Thread layout, 8 threads per row
                                Stride<_32, _1>>{},
                         Layout<Shape < _1, _1>>{}));  // Val layout, 1 val per store
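
All of the DefaultCopy atoms above become AutoVectorizingCopyWithAssumedAlignment<128>, i.e. a copy that may vectorize up to 128-bit accesses because the pointers are assumed 16-byte aligned. A small standalone sketch of the pattern; the 16x8 thread layout here is an assumed example, not one of the real kernel traits:

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

using namespace cute;

using Element = cutlass::half_t;
// Assumed 128-thread layout: 16 rows of 8 threads, each thread moving 8 halves
// (8 x 16 bit = 128 bit per access, matching the assumed alignment).
using GmemLayoutAtomExample = Layout<Shape<_16, _8>, Stride<_8, _1>>;
using GmemTiledCopyOExample = decltype(
    make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                    GmemLayoutAtomExample{},
                    Layout<Shape<_1, _8>>{}));  // 8 values per store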
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash
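
apply_softcap and calculate_dtanh implement the elementwise tanh soft-capping used by the forward pass and the matching derivative factor for a backward pass: if t = tanh(s * c), then dt/ds = (1 - t*t) * c. A host-side reference sketch of the same arithmetic (the value of c is an arbitrary example; in the kernels the scale is prepared on the host and passed in as params.softcap):

#include <cmath>
#include <cstdio>

int main() {
    const float c = 0.5f;                // example scale, stands in for params.softcap
    const float s = 3.0f;                // example attention score
    const float t  = std::tanh(s * c);   // forward, as in flash::apply_softcap
    const float dt = (1.f - t * t) * c;  // derivative factor, as in flash::calculate_dtanh
    std::printf("capped=%f  d(capped)/d(score)=%f\n", t, dt);
    return 0;
}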