summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-31 10:04:39 +0100
committerGitHub <noreply@github.com>2023-07-31 10:04:39 +0100
commit67834119fcfd961b4852458abe9426fbdfb2fd76 (patch)
treefc3a2e44841eb7fcdb11a6a6a77d4b6e20fbaac5 /candle-flash-attn
parent0ace420e66b86fd6146a02fe9b8aca6a41c0eabd (diff)
downloadcandle-67834119fcfd961b4852458abe9426fbdfb2fd76.tar.gz
candle-67834119fcfd961b4852458abe9426fbdfb2fd76.tar.bz2
candle-67834119fcfd961b4852458abe9426fbdfb2fd76.zip
Fix the flash-attention function names. (#282)
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--candle-flash-attn/src/lib.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 99b05229..092743f1 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -17,7 +17,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
impl candle::CustomOp3 for FlashAttn {
fn name(&self) -> &'static str {
- "flash-hdim32-sm80"
+ "flash-attn"
}
fn cpu_fwd(
@@ -192,7 +192,7 @@ struct FlashAttnVarLen {
impl candle::CustomOp3 for FlashAttnVarLen {
fn name(&self) -> &'static str {
- "flash-hdim32-sm80"
+ "flash-attn-varlen"
}
fn cpu_fwd(