diff options
Diffstat (limited to 'candle-flash-attn/build.rs')
-rw-r--r-- | candle-flash-attn/build.rs | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 4002770b..53fec5de 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -4,7 +4,7 @@
 use anyhow::{Context, Result};
 use std::path::PathBuf;
 
-const KERNEL_FILES: [&str; 17] = [
+const KERNEL_FILES: [&str; 33] = [
     "kernels/flash_api.cu",
     "kernels/flash_fwd_hdim128_fp16_sm80.cu",
     "kernels/flash_fwd_hdim160_fp16_sm80.cu",
@@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
     "kernels/flash_fwd_hdim32_bf16_sm80.cu",
     "kernels/flash_fwd_hdim64_bf16_sm80.cu",
     "kernels/flash_fwd_hdim96_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
+    "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
 ];
 
 fn main() -> Result<()> {