summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-26 14:16:37 +0100
committerGitHub <noreply@github.com>2023-07-26 14:16:37 +0100
commit2ce5f12513d0dafb04c7e345da9d4fba566cfa16 (patch)
treed8370aa035f667905e6f033e99e08fd93e677041 /candle-flash-attn
parentfa2b64d678ca83e2fbc3dabdecffbc778d5b067d (diff)
downloadcandle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.tar.gz
candle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.tar.bz2
candle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.zip
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params. * Use the appropriate kernel sizes. * Add all the kernel sizes. * Parallel compiling. * Reduce the amount of parallelism. * Add the missing kernel. * Fix a typo. * Remove bf16 support for now.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--candle-flash-attn/Cargo.toml2
-rw-r--r--candle-flash-attn/build.rs98
-rw-r--r--candle-flash-attn/kernels/flash_api.cu109
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu19
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu32
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu17
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu27
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu16
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu27
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu9
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu9
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu9
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu9
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu10
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu92
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu19
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu26
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu17
-rw-r--r--candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu23
-rw-r--r--candle-flash-attn/src/lib.rs16
20 files changed, 471 insertions, 115 deletions
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 25201a0e..9d21cf4a 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -16,3 +16,5 @@ half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
+num_cpus = "1.15.0"
+rayon = "1.7.0"
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 05affcbe..7a4588a4 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -2,11 +2,45 @@
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
+use rayon::prelude::*;
use std::path::PathBuf;
+use std::str::FromStr;
+
+const KERNEL_FILES: [&'static str; 9] = [
+ "flash_api.cu",
+ "flash_fwd_hdim128_fp16_sm80.cu",
+ "flash_fwd_hdim160_fp16_sm80.cu",
+ "flash_fwd_hdim192_fp16_sm80.cu",
+ "flash_fwd_hdim224_fp16_sm80.cu",
+ "flash_fwd_hdim256_fp16_sm80.cu",
+ "flash_fwd_hdim32_fp16_sm80.cu",
+ "flash_fwd_hdim64_fp16_sm80.cu",
+ "flash_fwd_hdim96_fp16_sm80.cu",
+ // "flash_fwd_hdim128_bf16_sm80.cu",
+ // "flash_fwd_hdim160_bf16_sm80.cu",
+ // "flash_fwd_hdim192_bf16_sm80.cu",
+ // "flash_fwd_hdim224_bf16_sm80.cu",
+ // "flash_fwd_hdim256_bf16_sm80.cu",
+ // "flash_fwd_hdim32_bf16_sm80.cu",
+ // "flash_fwd_hdim64_bf16_sm80.cu",
+ // "flash_fwd_hdim96_bf16_sm80.cu",
+];
fn main() -> Result<()> {
+ let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
+ |_| num_cpus::get_physical(),
+ |s| usize::from_str(&s).unwrap(),
+ );
+
+ rayon::ThreadPoolBuilder::new()
+ .num_threads(num_cpus)
+ .build_global()
+ .unwrap();
+
println!("cargo:rerun-if-changed=build.rs");
- println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
+ for kernel_file in KERNEL_FILES.iter() {
+ println!("cargo:rerun-if-changed=kernels/{kernel_file}");
+ }
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
println!("cargo:rerun-if-changed=kernels/flash.h");
@@ -16,42 +50,74 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
-
+ let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
- Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?,
- Ok(build_dir) => build_dir,
+ Err(_) => out_dir.clone(),
+ Ok(build_dir) => PathBuf::from(build_dir),
};
- let build_dir = PathBuf::from(build_dir);
set_cuda_include_dir()?;
let compute_cap = compute_cap()?;
- let mut command = std::process::Command::new("nvcc");
let out_file = build_dir.join("libflashattention.a");
- let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
+ let kernel_dir = PathBuf::from("kernels");
+ let cu_files: Vec<_> = KERNEL_FILES
+ .iter()
+ .map(|f| {
+ let mut obj_file = out_dir.join(f);
+ obj_file.set_extension("o");
+ (kernel_dir.join(f), obj_file)
+ })
+ .collect();
let should_compile = if out_file.exists() {
- let out_modified = out_file.metadata()?.modified()?;
- let in_modified = cu_file.metadata()?.modified()?;
- in_modified.duration_since(out_modified).is_ok()
+ cu_files.iter().any(|(cu_file, _)| {
+ let out_modified = out_file.metadata().unwrap().modified().unwrap();
+ let in_modified = cu_file.metadata().unwrap().modified().unwrap();
+ in_modified.duration_since(out_modified).is_ok()
+ })
} else {
true
};
if should_compile {
+ cu_files
+ .par_iter()
+ .map(|(cu_file, obj_file)| {
+ let mut command = std::process::Command::new("nvcc");
+ command
+ .arg(format!("--gpu-architecture=sm_{compute_cap}"))
+ .arg("-c")
+ .args(["-o", obj_file.to_str().unwrap()])
+ .args(["--default-stream", "per-thread"])
+ .arg("-Icutlass/include")
+ .arg("--expt-relaxed-constexpr")
+ .arg(cu_file);
+ let output = command
+ .spawn()
+ .context("failed spawning nvcc")?
+ .wait_with_output()?;
+ if !output.status.success() {
+ anyhow::bail!(
+ "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ String::from_utf8_lossy(&output.stdout),
+ String::from_utf8_lossy(&output.stderr)
+ )
+ }
+ Ok(())
+ })
+ .collect::<Result<()>>()?;
+ let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
+ let mut command = std::process::Command::new("nvcc");
command
- .arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
- .args(["--default-stream", "per-thread"])
- .arg("-Icutlass/include")
- .arg("--expt-relaxed-constexpr")
- .arg(cu_file);
+ .args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
- "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
new file mode 100644
index 00000000..323aeaad
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -0,0 +1,109 @@
+#include "flash_fwd_launch_template.h"
+
+// TODO: Switch back to handling bf16.
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+ run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
+ });
+}
+
+// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+// FP16_SWITCH(!params.is_bf16, [&] {
+// FWD_HEADDIM_SWITCH(params.d, [&] {
+// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+// });
+// });
+// }
+
+extern "C" void run_mha(
+ void *q_ptr,
+ void *k_ptr,
+ void *v_ptr,
+ void *o_ptr,
+ void *softmax_lse_ptr,
+
+ uint32_t q_batch_stride,
+ uint32_t k_batch_stride,
+ uint32_t v_batch_stride,
+ uint32_t o_batch_stride,
+
+ uint32_t q_row_stride,
+ uint32_t k_row_stride,
+ uint32_t v_row_stride,
+ uint32_t o_row_stride,
+
+ uint32_t q_head_stride,
+ uint32_t k_head_stride,
+ uint32_t v_head_stride,
+ uint32_t o_head_stride,
+
+ uint32_t b,
+ uint32_t h,
+ uint32_t h_k,
+ uint32_t d,
+ uint32_t d_rounded,
+ float softmax_scale,
+
+ uint32_t seqlen_q,
+ uint32_t seqlen_k,
+ uint32_t seqlen_q_rounded,
+ uint32_t seqlen_k_rounded,
+
+ int is_causal
+) {
+ Flash_fwd_params params;
+ // Reset the parameters
+ memset(&params, 0, sizeof(params));
+
+ // Set the pointers and strides.
+ params.q_ptr = q_ptr;
+ params.k_ptr = k_ptr;
+ params.v_ptr = v_ptr;
+ params.o_ptr = o_ptr;
+
+ params.softmax_lse_ptr = softmax_lse_ptr;
+
+ // All stride are in elements, not bytes.
+ params.q_batch_stride = q_batch_stride;
+ params.k_batch_stride = k_batch_stride;
+ params.v_batch_stride = v_batch_stride;
+ params.o_batch_stride = o_batch_stride;
+
+ params.q_row_stride = q_row_stride;
+ params.k_row_stride = k_row_stride;
+ params.v_row_stride = v_row_stride;
+ params.o_row_stride = o_row_stride;
+ params.q_head_stride = q_head_stride;
+ params.k_head_stride = k_head_stride;
+ params.v_head_stride = v_head_stride;
+ params.o_head_stride = o_head_stride;
+
+ // Set the dimensions.
+ params.b = b;
+ params.h = h;
+ params.h_k = h_k;
+ params.h_h_k_ratio = h / h_k;
+ params.seqlen_q = seqlen_q;
+ params.seqlen_k = seqlen_k;
+ params.seqlen_q_rounded = seqlen_q_rounded;
+ params.seqlen_k_rounded = seqlen_k_rounded;
+ params.d = d;
+ params.d_rounded = d_rounded;
+ params.is_causal = is_causal;
+
+ // Set the different scale values.
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+ params.p_dropout = 1.; // probability to keep
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+ params.rp_dropout = 1.f / params.p_dropout;
+ params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+ params.is_bf16 = 0;
+ params.cu_seqlens_q = nullptr;
+ params.cu_seqlens_k = nullptr;
+ params.p_ptr = nullptr;
+
+ cudaStream_t stream = 0; // Use the default stream.
+ run_mha_fwd(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
new file mode 100644
index 00000000..654400a7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// if (params.p_dropout == 1.f) {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
new file mode 100644
index 00000000..5b7254a9
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,32 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// if (params.p_dropout == 1.f) {
+// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
+// // 1st ones are good for H100, A100
+// // 2nd one is good for A6000 bc we get slightly better occupancy
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
+// // 1st one is good for H100, A100, A6000
+// }
+// }
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
new file mode 100644
index 00000000..6a9d60c3
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 00000000..6c40a164
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
+// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
+// // For A100, H100, 1st is fastest.
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
new file mode 100644
index 00000000..d2f4cba7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 00000000..2875c926
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// // This one is slightly faster for causal?
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
+// });
+// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
+// // For A6000, 1st is faster when causal, 3rd is faster when not causal
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 00000000..982fe7ea
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
new file mode 100644
index 00000000..4c083f7b
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
new file mode 100644
index 00000000..cb074a95
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
new file mode 100644
index 00000000..ddf5e132
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
new file mode 100644
index 00000000..81e359e1
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index d8f071ef..91e6331e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -20,94 +20,4 @@
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
-}
-
-
-extern "C" void run_mha(
- void *q_ptr,
- void *k_ptr,
- void *v_ptr,
- void *o_ptr,
- void *softmax_lse_ptr,
-
- uint32_t q_batch_stride,
- uint32_t k_batch_stride,
- uint32_t v_batch_stride,
- uint32_t o_batch_stride,
-
- uint32_t q_row_stride,
- uint32_t k_row_stride,
- uint32_t v_row_stride,
- uint32_t o_row_stride,
-
- uint32_t q_head_stride,
- uint32_t k_head_stride,
- uint32_t v_head_stride,
- uint32_t o_head_stride,
-
- uint32_t b,
- uint32_t h,
- uint32_t h_k,
- uint32_t d,
- uint32_t d_rounded,
- float softmax_scale,
-
- uint32_t seqlen_q,
- uint32_t seqlen_k,
- uint32_t seqlen_q_rounded,
- uint32_t seqlen_k_rounded,
-
- int is_causal
-) {
- Flash_fwd_params params;
- // Reset the parameters
- memset(&params, 0, sizeof(params));
-
- // Set the pointers and strides.
- params.q_ptr = q_ptr;
- params.k_ptr = k_ptr;
- params.v_ptr = v_ptr;
- params.o_ptr = o_ptr;
-
- params.softmax_lse_ptr = softmax_lse_ptr;
-
- // All stride are in elements, not bytes.
- params.q_batch_stride = q_batch_stride;
- params.k_batch_stride = k_batch_stride;
- params.v_batch_stride = v_batch_stride;
- params.o_batch_stride = o_batch_stride;
-
- params.q_row_stride = q_row_stride;
- params.k_row_stride = k_row_stride;
- params.v_row_stride = v_row_stride;
- params.o_row_stride = o_row_stride;
- params.q_head_stride = q_head_stride;
- params.k_head_stride = k_head_stride;
- params.v_head_stride = v_head_stride;
- params.o_head_stride = o_head_stride;
-
- // Set the dimensions.
- params.b = b;
- params.h = h;
- params.h_k = h_k;
- params.h_h_k_ratio = h / h_k;
- params.seqlen_q = seqlen_q;
- params.seqlen_k = seqlen_k;
- params.seqlen_q_rounded = seqlen_q_rounded;
- params.seqlen_k_rounded = seqlen_k_rounded;
- params.d = d;
- params.d_rounded = d_rounded;
- params.is_causal = is_causal;
-
- // Set the different scale values.
- params.scale_softmax = softmax_scale;
- params.scale_softmax_log2 = softmax_scale * M_LOG2E;
-
- params.p_dropout = 1.; // probability to keep
- params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
- params.rp_dropout = 1.f / params.p_dropout;
- params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-
- cudaStream_t stream = 0; // Use the default stream.
- run_mha_fwd_<cutlass::half_t, 32>(params, stream);
-}
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
new file mode 100644
index 00000000..fffcbebb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// if (params.p_dropout == 1.f) {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
new file mode 100644
index 00000000..01bd1716
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,26 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// if (params.p_dropout == 1.f) {
+// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+// // Using block size (64 x 256) is 27% slower for seqlen=2k
+// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
new file mode 100644
index 00000000..b0b27db5
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
new file mode 100644
index 00000000..820b63cb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,23 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
+// // This 3rd one is good for H100, and A100, A6000
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
+// // These two are always slower
+// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
+// });
+// }
+template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+} \ No newline at end of file
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 0bbb451d..b159aee2 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
- /* q_row_stride */ q_stride[q_rank - 3] as u32,
- /* k_row_stride */ k_stride[k_rank - 3] as u32,
- /* v_row_stride */ v_stride[v_rank - 3] as u32,
- /* o_row_stride */ o_stride[o_rank - 3] as u32,
- /* q_head_stride */ q_stride[q_rank - 2] as u32,
- /* k_head_stride */ k_stride[k_rank - 2] as u32,
- /* v_head_stride */ v_stride[v_rank - 2] as u32,
- /* o_head_stride */ o_stride[o_rank - 2] as u32,
+ /* q_row_stride */ q_stride[q_rank - 3] as u32,
+ /* k_row_stride */ k_stride[k_rank - 3] as u32,
+ /* v_row_stride */ v_stride[v_rank - 3] as u32,
+ /* o_row_stride */ o_stride[o_rank - 3] as u32,
+ /* q_head_stride */ q_stride[q_rank - 2] as u32,
+ /* k_head_stride */ k_stride[k_rank - 2] as u32,
+ /* v_head_stride */ v_stride[v_rank - 2] as u32,
+ /* o_head_stride */ o_stride[o_rank - 2] as u32,
/* b */ b_sz as u32,
/* h */ num_heads as u32,
/* h_k */ num_heads_k as u32,