diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-17 11:12:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-17 11:12:05 +0100 |
commit | 03be33eea482accbcf4c547728c2db7e24b7ebb0 (patch) | |
tree | da5680d6d705d9346edbac9f2ce4a05779b86343 /candle-flash-attn/src | |
parent | d32e8199cd6c8381aa309528675d6d6a88c0f850 (diff) | |
download | candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.gz candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.tar.bz2 candle-03be33eea482accbcf4c547728c2db7e24b7ebb0.zip |
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp.
* Simplify the custom-ops when no backward is required.
Diffstat (limited to 'candle-flash-attn/src')
-rw-r--r-- | candle-flash-attn/src/lib.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 092743f1..3c5fd455 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -178,7 +178,7 @@ pub fn flash_attn( softmax_scale, causal, }; - q.custom_op3(k, v, op) + q.apply_op3(k, v, op) } struct FlashAttnVarLen { @@ -402,5 +402,5 @@ pub fn flash_attn_varlen( seqlens_q: seqlens_q.clone(), seqlens_k: seqlens_k.clone(), }; - q.custom_op3(k, v, op) + q.apply_op3(k, v, op) } |