From f80fd44201a61833781131b4fdc83d2e58e1b559 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 4 Sep 2023 17:35:43 +0200
Subject: BF16 support for flash-attn. (#737)

---
 candle-flash-attn/src/lib.rs | 122 ++++++++++++++++++++++++++++---------------
 1 file changed, 81 insertions(+), 41 deletions(-)

(limited to 'candle-flash-attn')

diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index cdb4b083..b610915b 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
 
 pub struct FlashAttn {
     pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashAttn {
-    fn name(&self) -> &'static str {
-        "flash-attn"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttn {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
         let out_shape = q_l.shape().clone();
         let out_l = Layout::contiguous(&out_shape);
 
-        let q = q.as_cuda_slice::<f16>()?;
-        let k = k.as_cuda_slice::<f16>()?;
-        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.as_cuda_slice::<T>()?;
+        let k = k.as_cuda_slice::<T>()?;
+        let v = v.as_cuda_slice::<T>()?;
         let q = q.slice(q_l.start_offset()..);
         let k = k.slice(k_l.start_offset()..);
         let v = v.slice(v_l.start_offset()..);
@@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
         let causal = if self.causal { 1 } else { 0 };
@@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
+impl candle::CustomOp3 for FlashAttn {
+    fn name(&self) -> &'static str {
+        "flash-attn"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 /// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
@@ -191,24 +211,10 @@ struct FlashAttnVarLen {
     seqlens_k: Tensor,
 }
 
-impl candle::CustomOp3 for FlashAttnVarLen {
-    fn name(&self) -> &'static str {
-        "flash-attn-varlen"
-    }
-
-    fn cpu_fwd(
-        &self,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-        _: &CpuStorage,
-        _: &Layout,
-    ) -> Result<(CpuStorage, Shape)> {
-        candle::bail!("no cpu support for flash-attn")
-    }
-
-    fn cuda_fwd(
+impl FlashAttnVarLen {
+    fn cuda_fwd_t<
+        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+    >(
         &self,
         q: &candle::CudaStorage,
         q_l: &Layout,
@@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
     }
 }
 
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-attn-varlen"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        match q.dtype() {
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+        }
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 /// Flash-attention v2 layer with variable-length batching.
 ///
--
cgit v1.2.3
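
With the dtype dispatch above, callers do not need to change anything to use BF16: the public `flash_attn` helper documented in the context lines now accepts F16 or BF16 inputs and routes them to the matching `cuda_fwd_t` instantiation, while any other dtype hits the new bail path. Below is a minimal, hypothetical usage sketch (not part of this patch), assuming the `flash_attn(q, k, v, softmax_scale, causal)` signature exposed by candle-flash-attn, `candle` as the core crate name used by this file's imports, and inputs shaped (batch, seqlen, num_heads, head_dim):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Flash-attention only runs on CUDA; the cpu_fwd path above always bails.
    let device = Device::new_cuda(0)?;
    let (b, seqlen, heads, head_dim) = (1, 128, 8, 64);

    // Random activations, cast to BF16 so the new DType::BF16 match arm is taken.
    let q = Tensor::randn(0f32, 1.0, (b, seqlen, heads, head_dim), &device)?
        .to_dtype(DType::BF16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1.0 / (head_dim as f32).sqrt();

    // F16 inputs take the DType::F16 arm; any other dtype returns the
    // "flash-attn is only supported for f16/bf16" error added in this patch.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, true)?;
    println!("{:?}", out.shape());
    Ok(())
}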