author     Laurent Mazare <laurent.mazare@gmail.com>    2023-07-26 15:11:45 +0100
committer  GitHub <noreply@github.com>                  2023-07-26 15:11:45 +0100
commit     f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2 (patch)
tree       30f0634746b585cb69702e554f863e663f7dd5d9 /candle-flash-attn
parent     46f2d9f0acc0602e7f21bb3c35a5d5472fa3a515 (diff)
Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.
* Causality tweak.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/src/lib.rs  19
1 file changed, 18 insertions(+), 1 deletion(-)
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b159aee2..c2dec7d7 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
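
A minimal usage sketch of the new flash_attn entry point. The shapes and device setup here are assumptions for illustration: f16 tensors on a CUDA device, a (batch, seq_len, num_heads, head_dim) layout, and head dim 32 to match the FlashHdim32Sm80 kernel; none of this is spelled out in the commit itself.

use candle::{DType, Device, Result, Tensor};
use candle_flash_attn::flash_attn;

fn main() -> Result<()> {
    // The kernel runs on CUDA (sm80-class GPUs), so a CUDA device is required.
    let device = Device::new_cuda(0)?;
    // Assumed layout: (batch, seq_len, num_heads, head_dim), head_dim = 32.
    // The FFI expects f16, so sample in f32 and convert.
    let q = Tensor::randn(0f32, 1.0, (1, 8, 4, 32), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1.0, (1, 8, 4, 32), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1.0, (1, 8, 4, 32), &device)?.to_dtype(DType::F16)?;
    // The usual attention scaling: 1/sqrt(head_dim).
    let softmax_scale = 1.0 / (32.0f32).sqrt();
    // causal = true masks out future positions, per the "causality tweak".
    let out = flash_attn(&q, &k, &v, softmax_scale, true)?;
    println!("{:?}", out.shape());
    Ok(())
}

Note that flash_attn is just a thin wrapper: it forwards q, k, and v to Tensor::custom_op3 with FlashHdim32Sm80 carrying the two scalar parameters, which is what lines its signature up with the non-flash attention helpers.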