| field | value | date |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 14:16:37 +0100 |
| committer | GitHub <noreply@github.com> | 2023-07-26 14:16:37 +0100 |
| commit | 2ce5f12513d0dafb04c7e345da9d4fba566cfa16 (patch) | |
| tree | d8370aa035f667905e6f033e99e08fd93e677041 /candle-flash-attn | |
| parent | fa2b64d678ca83e2fbc3dabdecffbc778d5b067d (diff) | |
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.
* Use the appropriate kernel sizes.
* Add all the kernel sizes.
* Parallel compiling (sketched below).
* Reduce the amount of parallelism.
* Add the missing kernel.
* Fix a typo.
* Remove bf16 support for now.
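
To make the parallel-compilation scheme concrete, below is a minimal, self-contained sketch of the pattern the new `build.rs` follows: every kernel `.cu` file is compiled to an object file on a rayon worker, and a final `nvcc --lib` invocation archives the objects into `libflashattention.a`. This is a simplified sketch, not the real build script: it assumes `rayon` as a build dependency and `nvcc` on `PATH`, the kernel list here is only an illustrative subset, and the timestamp caching, compute-capability detection, and error handling of the actual `build.rs` are omitted.

```rust
// Sketch only: parallel nvcc compilation followed by a single archive step.
// Assumes `rayon = "1"` in [build-dependencies] and `nvcc` on PATH.
use rayon::prelude::*;
use std::path::PathBuf;
use std::process::Command;

fn main() {
    // Illustrative subset of the kernel list; the real script compiles all hdim variants.
    let kernels: &[&str] = &[
        "flash_fwd_hdim32_fp16_sm80.cu",
        "flash_fwd_hdim64_fp16_sm80.cu",
    ];
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").expect("OUT_DIR not set"));

    // Compile each .cu file to an object file, one nvcc process per rayon worker.
    let objects: Vec<PathBuf> = kernels
        .par_iter()
        .map(|name| {
            let obj = out_dir.join(name).with_extension("o");
            let status = Command::new("nvcc")
                .args(["--gpu-architecture=sm_80", "-c"])
                .args(["-o", obj.to_str().unwrap()])
                .arg(format!("kernels/{name}"))
                .status()
                .expect("failed to spawn nvcc");
            assert!(status.success(), "nvcc failed on {name}");
            obj
        })
        .collect();

    // Archive all object files into a single static library.
    let lib = out_dir.join("libflashattention.a");
    let status = Command::new("nvcc")
        .arg("--lib")
        .args(["-o", lib.to_str().unwrap()])
        .args(&objects)
        .status()
        .expect("failed to spawn nvcc");
    assert!(status.success(), "nvcc failed while archiving");

    println!("cargo:rustc-link-search={}", out_dir.display());
    println!("cargo:rustc-link-lib=static=flashattention");
}
```

The real build script additionally reads `RAYON_NUM_THREADS` to cap the number of concurrent nvcc processes and can cache artifacts under `CANDLE_FLASH_ATTN_BUILD_DIR` to avoid recompiling, as shown in the diff below.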
Diffstat (limited to 'candle-flash-attn')
20 files changed, 471 insertions(+), 115 deletions(-)
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 25201a0e..9d21cf4a 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -16,3 +16,5 @@ half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] anyhow = { version = "1", features = ["backtrace"] } +num_cpus = "1.15.0" +rayon = "1.7.0" diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 05affcbe..7a4588a4 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -2,11 +2,45 @@ // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment // variable in order to cache the compiled artifacts and avoid recompiling too often. use anyhow::{Context, Result}; +use rayon::prelude::*; use std::path::PathBuf; +use std::str::FromStr; + +const KERNEL_FILES: [&'static str; 9] = [ + "flash_api.cu", + "flash_fwd_hdim128_fp16_sm80.cu", + "flash_fwd_hdim160_fp16_sm80.cu", + "flash_fwd_hdim192_fp16_sm80.cu", + "flash_fwd_hdim224_fp16_sm80.cu", + "flash_fwd_hdim256_fp16_sm80.cu", + "flash_fwd_hdim32_fp16_sm80.cu", + "flash_fwd_hdim64_fp16_sm80.cu", + "flash_fwd_hdim96_fp16_sm80.cu", + // "flash_fwd_hdim128_bf16_sm80.cu", + // "flash_fwd_hdim160_bf16_sm80.cu", + // "flash_fwd_hdim192_bf16_sm80.cu", + // "flash_fwd_hdim224_bf16_sm80.cu", + // "flash_fwd_hdim256_bf16_sm80.cu", + // "flash_fwd_hdim32_bf16_sm80.cu", + // "flash_fwd_hdim64_bf16_sm80.cu", + // "flash_fwd_hdim96_bf16_sm80.cu", +]; fn main() -> Result<()> { + let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( + |_| num_cpus::get_physical(), + |s| usize::from_str(&s).unwrap(), + ); + + rayon::ThreadPoolBuilder::new() + .num_threads(num_cpus) + .build_global() + .unwrap(); + println!("cargo:rerun-if-changed=build.rs"); - println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu"); + for kernel_file in KERNEL_FILES.iter() { + println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + } println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); println!("cargo:rerun-if-changed=kernels/flash.h"); @@ -16,42 +50,74 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=kernels/kernel_traits.h"); println!("cargo:rerun-if-changed=kernels/block_info.h"); println!("cargo:rerun-if-changed=kernels/static_switch.h"); - + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { - Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?, - Ok(build_dir) => build_dir, + Err(_) => out_dir.clone(), + Ok(build_dir) => PathBuf::from(build_dir), }; - let build_dir = PathBuf::from(build_dir); set_cuda_include_dir()?; let compute_cap = compute_cap()?; - let mut command = std::process::Command::new("nvcc"); let out_file = build_dir.join("libflashattention.a"); - let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu"); + let kernel_dir = PathBuf::from("kernels"); + let cu_files: Vec<_> = KERNEL_FILES + .iter() + .map(|f| { + let mut obj_file = out_dir.join(f); + obj_file.set_extension("o"); + (kernel_dir.join(f), obj_file) + }) + .collect(); let should_compile = if out_file.exists() { - let out_modified = out_file.metadata()?.modified()?; - let in_modified = cu_file.metadata()?.modified()?; - in_modified.duration_since(out_modified).is_ok() + cu_files.iter().any(|(cu_file, _)| { + let out_modified = out_file.metadata().unwrap().modified().unwrap(); + let 
in_modified = cu_file.metadata().unwrap().modified().unwrap(); + in_modified.duration_since(out_modified).is_ok() + }) } else { true }; if should_compile { + cu_files + .par_iter() + .map(|(cu_file, obj_file)| { + let mut command = std::process::Command::new("nvcc"); + command + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("-c") + .args(["-o", obj_file.to_str().unwrap()]) + .args(["--default-stream", "per-thread"]) + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg(cu_file); + let output = command + .spawn() + .context("failed spawning nvcc")? + .wait_with_output()?; + if !output.status.success() { + anyhow::bail!( + "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + Ok(()) + }) + .collect::<Result<()>>()?; + let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>(); + let mut command = std::process::Command::new("nvcc"); command - .arg(format!("--gpu-architecture=sm_{compute_cap}")) .arg("--lib") .args(["-o", out_file.to_str().unwrap()]) - .args(["--default-stream", "per-thread"]) - .arg("-Icutlass/include") - .arg("--expt-relaxed-constexpr") - .arg(cu_file); + .args(obj_files); let output = command .spawn() .context("failed spawning nvcc")? .wait_with_output()?; if !output.status.success() { anyhow::bail!( - "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr) ) diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu new file mode 100644 index 00000000..323aeaad --- /dev/null +++ b/candle-flash-attn/kernels/flash_api.cu @@ -0,0 +1,109 @@ +#include "flash_fwd_launch_template.h" + +// TODO: Switch back to handling bf16. +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream); + }); +} + +// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { +// FP16_SWITCH(!params.is_bf16, [&] { +// FWD_HEADDIM_SWITCH(params.d, [&] { +// run_mha_fwd_<elem_type, kHeadDim>(params, stream); +// }); +// }); +// } + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + void *softmax_lse_ptr, + + uint32_t q_batch_stride, + uint32_t k_batch_stride, + uint32_t v_batch_stride, + uint32_t o_batch_stride, + + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t o_row_stride, + + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + uint32_t o_head_stride, + + uint32_t b, + uint32_t h, + uint32_t h_k, + uint32_t d, + uint32_t d_rounded, + float softmax_scale, + + uint32_t seqlen_q, + uint32_t seqlen_k, + uint32_t seqlen_q_rounded, + uint32_t seqlen_k_rounded, + + int is_causal +) { + Flash_fwd_params params; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + params.o_ptr = o_ptr; + + params.softmax_lse_ptr = softmax_lse_ptr; + + // All stride are in elements, not bytes. 
+ params.q_batch_stride = q_batch_stride; + params.k_batch_stride = k_batch_stride; + params.v_batch_stride = v_batch_stride; + params.o_batch_stride = o_batch_stride; + + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.o_row_stride = o_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_head_stride = o_head_stride; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + params.is_causal = is_causal; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.p_dropout = 1.; // probability to keep + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + params.is_bf16 = 0; + params.cu_seqlens_q = nullptr; + params.cu_seqlens_k = nullptr; + params.p_ptr = nullptr; + + cudaStream_t stream = 0; // Use the default stream. + run_mha_fwd(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 00000000..654400a7 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream); +// } else { +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 00000000..5b7254a9 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream); +// // 1st ones are good for H100, A100 +// // 2nd one is good for A6000 bc we get slightly better occupancy +// } else { +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream); +// // 1st one is good for H100, A100, A6000 +// } +// } + +template<> +void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::half_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 00000000..6a9d60c3 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 00000000..6c40a164 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream); +// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest. +// // For A100, H100, 1st is fastest. +// }); +// } +template<> +void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::half_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 00000000..d2f4cba7 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream); +// }); +// } +template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 00000000..2875c926 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream); +// // This one is slightly faster for causal? +// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream); +// }); +// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout +// // For A6000, 1st is faster when causal, 3rd is faster when not causal +// } +template<> +void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::half_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 00000000..982fe7ea --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 00000000..4c083f7b --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::half_t>(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 00000000..cb074a95 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 00000000..ddf5e132 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::half_t>(params, stream); +} diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 00000000..81e359e1 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index d8f071ef..91e6331e 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -20,94 +20,4 @@ template<> void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32<cutlass::half_t>(params, stream); -} - - -extern "C" void run_mha( - void *q_ptr, - void *k_ptr, - void *v_ptr, - void *o_ptr, - void *softmax_lse_ptr, - - uint32_t q_batch_stride, - uint32_t k_batch_stride, - uint32_t v_batch_stride, - uint32_t o_batch_stride, - - uint32_t q_row_stride, - uint32_t k_row_stride, - uint32_t v_row_stride, - uint32_t o_row_stride, - - uint32_t q_head_stride, - uint32_t k_head_stride, - uint32_t v_head_stride, - uint32_t o_head_stride, - - uint32_t b, - uint32_t h, - uint32_t h_k, - uint32_t d, - uint32_t d_rounded, - float softmax_scale, - - uint32_t seqlen_q, - uint32_t seqlen_k, - uint32_t seqlen_q_rounded, - uint32_t seqlen_k_rounded, - - int is_causal -) { - Flash_fwd_params params; - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - // Set the pointers and strides. - params.q_ptr = q_ptr; - params.k_ptr = k_ptr; - params.v_ptr = v_ptr; - params.o_ptr = o_ptr; - - params.softmax_lse_ptr = softmax_lse_ptr; - - // All stride are in elements, not bytes. - params.q_batch_stride = q_batch_stride; - params.k_batch_stride = k_batch_stride; - params.v_batch_stride = v_batch_stride; - params.o_batch_stride = o_batch_stride; - - params.q_row_stride = q_row_stride; - params.k_row_stride = k_row_stride; - params.v_row_stride = v_row_stride; - params.o_row_stride = o_row_stride; - params.q_head_stride = q_head_stride; - params.k_head_stride = k_head_stride; - params.v_head_stride = v_head_stride; - params.o_head_stride = o_head_stride; - - // Set the dimensions. - params.b = b; - params.h = h; - params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; - params.d_rounded = d_rounded; - params.is_causal = is_causal; - - // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; - - params.p_dropout = 1.; // probability to keep - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - - cudaStream_t stream = 0; // Use the default stream. - run_mha_fwd_<cutlass::half_t, 32>(params, stream); -} +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 00000000..fffcbebb --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream); +// } else { +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 00000000..01bd1716 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower +// // Using block size (64 x 256) is 27% slower for seqlen=2k +// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream); +// } else { +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream); +// } +// } +template<> +void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::half_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 00000000..b0b27db5 --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,17 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream); +// }); +// } +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 00000000..820b63cb --- /dev/null +++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream); +// // This 3rd one is good for H100, and A100, A6000 +// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream); +// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream); +// // These two are always slower +// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream); +// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream); +// }); +// } +template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::half_t>(params, stream); +}
\ No newline at end of file diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 0bbb451d..b159aee2 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 { /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, - /* q_row_stride */ q_stride[q_rank - 3] as u32, - /* k_row_stride */ k_stride[k_rank - 3] as u32, - /* v_row_stride */ v_stride[v_rank - 3] as u32, - /* o_row_stride */ o_stride[o_rank - 3] as u32, - /* q_head_stride */ q_stride[q_rank - 2] as u32, - /* k_head_stride */ k_stride[k_rank - 2] as u32, - /* v_head_stride */ v_stride[v_rank - 2] as u32, - /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, /* b */ b_sz as u32, /* h */ num_heads as u32, /* h_k */ num_heads_k as u32, |
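
The `extern "C"` `run_mha` entry point moved into `flash_api.cu` above is what the Rust side ultimately calls, and the `lib.rs` hunk only re-aligns the stride arguments at that call site without changing their values. As a reference for how the two sides line up, a matching Rust declaration could look like the sketch below. This is an illustration derived from the C signature in the diff, not the crate's actual FFI module; in particular, the pointer const-ness is a Rust-side choice here, since the C side takes plain `void *` for every buffer.

```rust
use core::ffi::{c_int, c_void};

extern "C" {
    // Forward flash-attention entry point exported by libflashattention.a.
    // All strides are in elements, not bytes; pointers are raw CUDA device pointers.
    pub fn run_mha(
        q_ptr: *const c_void,
        k_ptr: *const c_void,
        v_ptr: *const c_void,
        o_ptr: *mut c_void,
        softmax_lse_ptr: *mut c_void,

        q_batch_stride: u32,
        k_batch_stride: u32,
        v_batch_stride: u32,
        o_batch_stride: u32,

        q_row_stride: u32,
        k_row_stride: u32,
        v_row_stride: u32,
        o_row_stride: u32,

        q_head_stride: u32,
        k_head_stride: u32,
        v_head_stride: u32,
        o_head_stride: u32,

        b: u32,
        h: u32,
        h_k: u32,
        d: u32,
        d_rounded: u32,
        softmax_scale: f32,

        seqlen_q: u32,
        seqlen_k: u32,
        seqlen_q_rounded: u32,
        seqlen_k_rounded: u32,

        is_causal: c_int,
    );
}
```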