From f052ba76cbf88f8e4f9fe38e76f7a2673da6b5f2 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 26 Jul 2023 15:11:45 +0100
Subject: Lining up the flash attn version with the non-flash one. (#248)

* Move the flash-attn function in the proper crate.

* Causality tweak.
---
 candle-flash-attn/src/lib.rs | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

(limited to 'candle-flash-attn')

diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b159aee2..c2dec7d7 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,7 +3,7 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Error, Layout, Result, Shape};
+use candle::{CpuStorage, Error, Layout, Result, Shape, Tensor};
 use half::f16;
 
 pub struct FlashHdim32Sm80 {
@@ -144,3 +144,20 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         Ok((dst, out_shape))
     }
 }
+
+pub fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    q.custom_op3(
+        k,
+        v,
+        FlashHdim32Sm80 {
+            softmax_scale,
+            causal,
+        },
+    )
+}
-- 
cgit v1.2.3