summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-26 15:11:45 +0100
committer: GitHub <noreply@github.com> 2023-07-26 15:11:45 +0100
commitf052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2 (patch)
tree30f0634746b585cb69702e554f863e663f7dd5d9 /candle-flash-attn
parent46f2d9f0acc0602e7f21bb3c35a5d5472fa3a515 (diff)
downloadcandle-f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2.tar.gz
candle-f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2.tar.bz2
candle-f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2.zip
Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate. * Causality tweak.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r-- candle-flash-attn/src/lib.rs 19
1 file changed, 18 insertions, 1 deletion
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b159aee2..c2dec7d7 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
use half::f16;
pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
Ok((dst, out_shape))
}
}
+
+pub fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ q.custom_op3(
+ k,
+ v,
+ FlashHdim32Sm80 {
+ softmax_scale,
+ causal,
+ },
+ )
+}