| author | Michael Feil <63565275+michaelfeil@users.noreply.github.com> | 2024-12-31 09:32:22 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-31 09:32:22 +0100 |
| commit | 71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch) | |
| tree | 207e7d050e9c4bd685a563e457b2e9ef59b66f20 | |
| parent | d60eba140820326ffc7ec39a8982e91feb462732 (diff) | |
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
41 files changed, 139 insertions, 82 deletions
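
For context on the "SoftCap" half of the title: softcapping squashes attention logits through a scaled tanh so they stay bounded, and this commit routes that through a shared flash::apply_softcap helper (see utils.h at the end of the diff). Below is a minimal host-side C++ sketch of the conceptual transform only; it is not the kernel code, and the remark about params.softcap being pre-scaled on the host is an assumption carried over from upstream flash-attention rather than something shown in this diff.

```cpp
// Minimal sketch of tanh softcapping on raw attention scores (illustration only).
// The kernel's flash::apply_softcap computes tensor(i) = fast_tanh(tensor(i) * softcap);
// conceptually the transform is score -> cap * tanh(score / cap), with the 1/cap factor
// assumed to be folded into the multiplier on the host side (not shown in this diff).
#include <cmath>
#include <cstdio>

int main() {
    const float cap = 30.0f;  // example softcap value (hypothetical)
    const float scores[] = {-100.0f, -5.0f, 0.0f, 5.0f, 100.0f};
    for (float s : scores) {
        const float capped = cap * std::tanh(s / cap);  // bounded to (-cap, cap), ~identity near 0
        std::printf("%8.2f -> %8.4f\n", s, capped);
    }
    return 0;
}
```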
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 53fec5de..37247646 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -54,6 +54,7 @@ fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
     println!("cargo:rerun-if-changed=kernels/block_info.h");
     println!("cargo:rerun-if-changed=kernels/static_switch.h");
+    println!("cargo:rerun-if-changed=kernels/hardware_info.h");

     let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
     let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
         Err(_) =>
diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass
-Subproject 7d49e6c7e2f8896c47f586706e67e1fb215529d
+Subproject 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 3a23a1e1..cf60d653 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }

@@ -30,13 +31,14 @@ struct BlockInfo {

     template <typename index_t>
     __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }

     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
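
The leftpad_k field added to BlockInfo above lets each batch entry of the K/V cache start leftpad_k rows into its buffer: the effective cache length shrinks by leftpad_k and the base offset skips past the padded rows. A small host-side C++ sketch of that arithmetic (names mirror the diff; the standalone program and its concrete numbers are hypothetical):

```cpp
// Host-side sketch of the leftpad_k arithmetic from BlockInfo (illustration only).
#include <cstdint>
#include <cstdio>

int main() {
    const int seqlen_k  = 1000;                 // params.seqlen_k for this batch entry (hypothetical)
    const int leftpad_k = 128;                  // params.leftpad_k[bidb]: rows skipped at the front
    const int bidb = 2;                         // batch index
    const std::int64_t row_stride   = 256;      // elements per K row (hypothetical)
    const std::int64_t batch_stride = 4096 * row_stride;

    // Mirrors: seqlen_k_cache(... - leftpad_k) -- padded rows no longer count as cache content.
    const int seqlen_k_cache = seqlen_k - leftpad_k;

    // Mirrors k_offset() on the non-varlen path (sum_s_k == -1):
    // the per-batch base pointer now starts leftpad_k rows into the buffer.
    const std::int64_t k_offset = bidb * batch_stride + leftpad_k * row_stride;

    std::printf("seqlen_k_cache = %d, k_offset = %lld\n", seqlen_k_cache, (long long)k_offset);
    return 0;
}
```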
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 88c2f22a..f21e4d62 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,13 +7,7 @@
 #include <cuda.h>
 #include <vector>

-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;

     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
index f19049b4..9383c102 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
index cb135741..f03abda4 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
index dfb04b78..c616628c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
index 6df16b2c..4ff6b9fb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
index 230af906..d6d4371b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
index cf1ffad2..5af68ac3 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
index 1fc5ac59..1ef511a6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
index a9796ade..96abfbd8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
index 94792d4d..077d25d0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
index 76d5136b..ea5f265f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
index 9e5b21e0..a4a7bc24 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
index b4019a0b..c30c4a14 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
index a12a5f4a..db69f21c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
index 8690bdb1..9a11724b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
index f01dad09..d02edae0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
index 7ec1e16b..28150ed0 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
index 3d816ab6..f84e978c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
index c6c55229..c52f0417 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
index 0149abac..f96f7edc 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
index 9c9a1715..9c7c6b93 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
index 29097ac3..e21d0408 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
index cb52f34f..f377a5b8 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
index 7bdadefb..74e4d66a 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index 44b38816..e85db18e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
index 99cd728b..9297e8bb 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
index c11096ac..8364b1e7 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
index 2fbcd44e..1c6ed7ef 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
index 7b65a9c9..3c87573b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
index 6fb3cf64..49fae856 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
index e696b2f2..c5af1cf6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
index bb3b744d..b0d6c992 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
index 5f3accc3..c97aa33f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 1bf77f81..b6b26d52 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,6 +4,8 @@

 #pragma once

+// #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
 #include <cute/tensor.hpp>

 #include <cutlass/cutlass.h>
@@ -22,14 +24,6 @@ namespace flash {

 using namespace cute;

-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -328,7 +322,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -394,7 +388,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         flash::cp_async_wait<0>();
@@ -691,7 +685,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -712,9 +706,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // if (cute::thread(8, 0)) { print_tensor(gCos); }
     // if (cute::thread(0, 0)) { print_tensor(tRgCos); }

-    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+    // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+    const index_t row_offset_knew = bidb * params.knew_batch_stride
         + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+    // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+    const index_t row_offset_vnew = bidb * params.vnew_batch_stride
         + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
     // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
     // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -792,7 +788,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
-            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+            const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
             Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
@@ -886,7 +882,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

@@ -961,7 +957,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }

         flash::cp_async_wait<0>();
@@ -1226,7 +1222,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kBlockN = kNThreads / kBlockM;
     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 9e5449d7..bb581eb3 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -3,11 +3,11 @@
  ******************************************************************************/

 #pragma once
-
-// #include <ATen/cuda/CUDAContext.h>
+// #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

 #include "error.h"
 #include "static_switch.h"
+#include "hardware_info.h"
 #include "flash.h"
 #include "flash_fwd_kernel.h"
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                 // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                 // If Is_local, set Is_causal to false
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                 // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         if (is_sm8x) {
@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if constexpr(!Is_dropout) {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
-    bool is_sm8x = cuda_is_sm8x();
+    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+    bool is_sm8x = cc_major == 8 && cc_minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // For A100, H100, 128 x 32 is the fastest.
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
diff --git a/candle-flash-attn/kernels/hardware_info.h b/candle-flash-attn/kernels/hardware_info.h
new file mode 100644
index 00000000..d5c48d35
--- /dev/null
+++ b/candle-flash-attn/kernels/hardware_info.h
@@ -0,0 +1,42 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <tuple>
+#include <cstdio>
+
+#if !defined(__CUDACC_RTC__)
+#include "cuda_runtime.h"
+#endif
+
+#define CHECK_CUDA(call)                                                          \
+    do {                                                                          \
+        cudaError_t status_ = call;                                               \
+        if (status_ != cudaSuccess) {                                             \
+            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,       \
+                    cudaGetErrorString(status_));                                 \
+            exit(1);                                                              \
+        }                                                                         \
+    } while (0)
+
+
+inline int get_current_device() {
+    int device;
+    CHECK_CUDA(cudaGetDevice(&device));
+    return device;
+}
+
+inline std::tuple<int, int> get_compute_capability(int device) {
+    int capability_major, capability_minor;
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
+    CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
+    return {capability_major, capability_minor};
+}
+
+inline int get_num_sm(int device) {
+    int multiprocessor_count;
+    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
+    return multiprocessor_count;
+}
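
The new hardware_info.h above replaces the launch template's previous cuda_is_sm8x() helper with explicit compute-capability queries. A minimal usage sketch, assuming a hypothetical standalone translation unit compiled against the CUDA runtime and only the helpers shown in the new header:

```cpp
// Hypothetical standalone check built on the helpers added in hardware_info.h.
#include <cstdio>
#include "hardware_info.h"

int main() {
    const int device = get_current_device();
    const auto [cc_major, cc_minor] = get_compute_capability(device);
    // Same predicate the launch template now uses to pick sm86/sm89 tile sizes.
    const bool is_sm8x = cc_major == 8 && cc_minor > 0;
    std::printf("device %d: sm_%d%d, %d SMs, is_sm8x = %d\n",
                device, cc_major, cc_minor, get_num_sm(device), int(is_sm8x));
    return 0;
}
```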
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 5a7b7491..8c089748 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutO = decltype(tile_to_shape(
         SmemLayoutAtomO{},
         Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
-    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+    using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;

     static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
     static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
                  Stride< _16, _1>>
     >;
     using GmemTiledCopyOaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
     using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
     using GmemTiledCopyRotcossinCont = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
 };
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
         composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
     using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

-    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     using SmemLayoutQdOtransposed = decltype(
         composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdKV = decltype(tile_to_shape(
         SmemLayoutAtomdKV{},
         make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     using SmemLayoutAtomdQ = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutdQ = decltype(tile_to_shape(
         SmemLayoutAtomdQ{},
         make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
-    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

     // Double buffer for sQ
     static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
     using Gmem_copy_struct = std::conditional_t<
         Has_cp_async,
         SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-        DefaultCopy
+        AutoVectorizingCopyWithAssumedAlignment<128>
     >;
     using GmemTiledCopyQKV = decltype(
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopydO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemTiledCopydKV = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemTiledCopydQ = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
     using GmemLayoutAtomdQaccum = std::conditional_t<
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
                  Stride< _16, _1>>
     >;
     using GmemTiledCopydQaccum = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         GmemLayoutAtomdQaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

     using GmemTiledCopydQaccumAtomicAdd = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
                         Layout<Shape <_8, _32>,  // Thread layout, 8 threads per row
                                Stride<_32, _1>>{},
                         Layout<Shape < _1, _1>>{}));  // Val layout, 1 val per store
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 708aeddf..b7408ec4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash
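
calculate_dtanh, added alongside apply_softcap in utils.h, looks like the backward-pass counterpart for the softcap transform (consistent with the "[1/n]" in the title): for t = tanh(x * softcap), dt/dx = softcap * (1 - t*t), which is exactly what the helper computes from the stored tanh outputs. A short host-side C++ check of that identity (illustration only, independent of the CUDA code):

```cpp
// Sanity check: analytic derivative of tanh(x * softcap) vs. a central finite difference.
#include <cmath>
#include <cstdio>

int main() {
    const float softcap = 0.5f;
    const float eps = 1e-3f;
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.7f, 3.0f};
    for (float x : xs) {
        const float t = std::tanh(x * softcap);
        const float analytic = (1.0f - t * t) * softcap;     // calculate_dtanh's formula
        const float numeric  = (std::tanh((x + eps) * softcap)
                              - std::tanh((x - eps) * softcap)) / (2.0f * eps);
        std::printf("x = %5.2f  analytic = %8.5f  numeric = %8.5f\n", x, analytic, numeric);
    }
    return 0;
}
```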