author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-04 17:35:43 +0200
committer | GitHub <noreply@github.com>               | 2023-09-04 16:35:43 +0100
commit    | f80fd44201a61833781131b4fdc83d2e58e1b559
tree      | aa4645fbec3a6fd57f129dc1a5394317ebb46e02 /candle-flash-attn
parent    | 0d00c06a83b98c55d146564b96d913de4cec71c7
BF16 support for flash-attn. (#737)
Diffstat (limited to 'candle-flash-attn')
-rw-r--r-- | candle-flash-attn/src/lib.rs | 122
1 file changed, 81 insertions, 41 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index cdb4b083..b610915b 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
 
 pub struct FlashAttn {
     pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashAttn {
-    fn name(&self) -> &'static str {
-        "flash-attn"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttn {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
 
-        let q = q.as_cuda_slice::<f16>()?;
-        let k = k.as_cuda_slice::<f16>()?;
-        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.as_cuda_slice::<T>()?;
+        let k = k.as_cuda_slice::<T>()?;
+        let v = v.as_cuda_slice::<T>()?;
         let q = q.slice(q_l.start_offset()..);
         let k = k.slice(k_l.start_offset()..);
         let v = v.slice(v_l.start_offset()..);
@@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
         let causal = if self.causal { 1 } else { 0 };
@@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
+impl candle::CustomOp3 for FlashAttn {
+    fn name(&self) -> &'static str {
+        "flash-attn"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 /// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
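With this change the fixed-length op dispatches on the query dtype instead of hard-coding f16, so either F16 or BF16 tensors reach the same CUDA kernel path, while any other dtype now hits the explicit candle::bail! in cuda_fwd. A minimal usage sketch, assuming the crate's public flash_attn(q, k, v, softmax_scale, causal) helper, a CUDA device, and the (batch, seq_len, num_heads, head_dim) layout; names and shapes here are illustrative, not taken from this commit:

    // Hypothetical example: run flash-attn over BF16 inputs on the first CUDA device.
    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::new_cuda(0)?;
        // Assumed (batch, seq_len, num_heads, head_dim) layout; values are random.
        let q = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let k = q.clone();
        let v = q.clone();
        // softmax_scale is typically 1/sqrt(head_dim).
        let softmax_scale = 1.0 / (64f32).sqrt();
        // Dispatches to cuda_fwd_t::<bf16> via the DType match added in this commit.
        let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, false)?;
        println!("{:?}", out.dims()); // same shape as q
        Ok(())
    }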
@@ -191,24 +211,10 @@ struct FlashAttnVarLen {
     seqlens_k: Tensor,
 }
 
-impl candle::CustomOp3 for FlashAttnVarLen {
-    fn name(&self) -> &'static str {
-        "flash-attn-varlen"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttnVarLen {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
     }
 }
 
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-attn-varlen"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 /// Flash-attention v2 layer with variable-length batching.
 ///
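The variable-length op gets the same treatment, so packed batches can also run in BF16. A sketch under the same assumptions, additionally assuming the crate's flash_attn_varlen helper takes cumulative sequence lengths (u32, one entry per sequence boundary) and (total_tokens, num_heads, head_dim) inputs; treat the parameter order shown here as illustrative rather than authoritative:

    // Hypothetical example: two sequences of 100 and 28 tokens packed into one buffer.
    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::new_cuda(0)?;
        let q = Tensor::randn(0f32, 1.0, (128, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let k = q.clone();
        let v = q.clone();
        // Cumulative sequence lengths: [0, 100, 128] marks the two sequence boundaries.
        let seqlens = Tensor::new(&[0u32, 100u32, 128u32], &device)?;
        let out = candle_flash_attn::flash_attn_varlen(
            &q, &k, &v, &seqlens, &seqlens,
            100, // max_seqlen_q
            100, // max_seqlen_k
            1.0 / (64f32).sqrt(),
            false, // causal
        )?;
        println!("{:?}", out.dims());
        Ok(())
    }

Both ops now route through the same generic cuda_fwd_t::<T>, so supporting another dtype later only means extending the match in cuda_fwd.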