summaryrefslogtreecommitdiff
path: root/candle-flash-attn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-17 11:12:05 +0100
committerGitHub <noreply@github.com>2023-08-17 11:12:05 +0100
commit03be33eea482accbcf4c547728c2db7e24b7ebb0 (patch)
treeda5680d6d705d9346edbac9f2ce4a05779b86343 /candle-flash-attn/src
parentd32e8199cd6c8381aa309528675d6d6a88c0f850 (diff)
downloadcandle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.gz
candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.bz2
candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.zip
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp. * Simplify the custom-ops when no backward is required.
Diffstat (limited to 'candle-flash-attn/src')
-rw-r--r--candle-flash-attn/src/lib.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 092743f1..3c5fd455 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -178,7 +178,7 @@ pub fn flash_attn(
softmax_scale,
causal,
};
- q.custom_op3(k, v, op)
+ q.apply_op3(k, v, op)
}
struct FlashAttnVarLen {
@@ -402,5 +402,5 @@ pub fn flash_attn_varlen(
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
};
- q.custom_op3(k, v, op)
+ q.apply_op3(k, v, op)
}