summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-flash-attn/build.rs18
-rw-r--r--candle-flash-attn/kernels/flash_api.cu26
-rw-r--r--candle-flash-attn/src/ffi.rs1
-rw-r--r--candle-flash-attn/src/lib.rs2
4 files changed, 25 insertions, 22 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 773c5638..2a3b7eb1 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
-const KERNEL_FILES: [&str; 9] = [
+const KERNEL_FILES: [&str; 17] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
@@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [
"flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu",
- // "flash_fwd_hdim128_bf16_sm80.cu",
- // "flash_fwd_hdim160_bf16_sm80.cu",
- // "flash_fwd_hdim192_bf16_sm80.cu",
- // "flash_fwd_hdim224_bf16_sm80.cu",
- // "flash_fwd_hdim256_bf16_sm80.cu",
- // "flash_fwd_hdim32_bf16_sm80.cu",
- // "flash_fwd_hdim64_bf16_sm80.cu",
- // "flash_fwd_hdim96_bf16_sm80.cu",
+ "flash_fwd_hdim128_bf16_sm80.cu",
+ "flash_fwd_hdim160_bf16_sm80.cu",
+ "flash_fwd_hdim192_bf16_sm80.cu",
+ "flash_fwd_hdim224_bf16_sm80.cu",
+ "flash_fwd_hdim256_bf16_sm80.cu",
+ "flash_fwd_hdim32_bf16_sm80.cu",
+ "flash_fwd_hdim64_bf16_sm80.cu",
+ "flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index d928bcb6..72991257 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,20 +1,19 @@
#include "flash_fwd_launch_template.h"
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
- FWD_HEADDIM_SWITCH(params.d, [&] {
- run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
- });
-}
-
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-// FP16_SWITCH(!params.is_bf16, [&] {
-// FWD_HEADDIM_SWITCH(params.d, [&] {
-// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-// });
+// FWD_HEADDIM_SWITCH(params.d, [&] {
+// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
// });
// }
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+ run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+ });
+ });
+}
+
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
- int is_causal
+ int is_causal,
+ int is_bf16
) {
Flash_fwd_params params;
// Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
- params.is_bf16 = 0;
+ params.is_bf16 = is_bf16;
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index ae61c405..90f34e43 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -38,6 +38,7 @@ extern "C" {
seqlen_k_rounded: u32,
is_causal: c_int,
+ is_bf16: c_int,
);
}
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 3c5fd455..cdb4b083 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -146,6 +146,7 @@ impl candle::CustomOp3 for FlashAttn {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
+ /* is_bf16 */ 0,
)
}
@@ -354,6 +355,7 @@ impl candle::CustomOp3 for FlashAttnVarLen {
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
+ /* is_bf16 */ 0,
)
}