summaryrefslogtreecommitdiff
path: root/candle-flash-attn/build.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/build.rs')
-rw-r--r--candle-flash-attn/build.rs18
1 files changed, 17 insertions, 1 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 4002770b..53fec5de 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -4,7 +4,7 @@
use anyhow::{Context, Result};
use std::path::PathBuf;
-const KERNEL_FILES: [&str; 17] = [
+const KERNEL_FILES: [&str; 33] = [
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
@@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
];
fn main() -> Result<()> {