author     Laurent Mazare <laurent.mazare@gmail.com>  2023-09-04 17:35:43 +0200
committer  GitHub <noreply@github.com>  2023-09-04 16:35:43 +0100
commit     f80fd44201a61833781131b4fdc83d2e58e1b559 (patch)
tree       aa4645fbec3a6fd57f129dc1a5394317ebb46e02 /candle-flash-attn
parent     0d00c06a83b98c55d146564b96d913de4cec71c7 (diff)
BF16 support for flash-attn. (#737)
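
With this change, both `FlashAttn` and `FlashAttnVarLen` dispatch on the input dtype, so BF16 inputs are accepted alongside F16. A minimal sketch of exercising the new path (assuming a CUDA build of candle and the crate's public `flash_attn` helper; shapes and values are illustrative only):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let device = Device::new_cuda(0)?;
        // q/k/v are (batch, seq_len, num_heads, head_dim); BF16 now works as well as F16.
        let q = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let k = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let v = Tensor::randn(0f32, 1.0, (1, 128, 8, 64), &device)?.to_dtype(DType::BF16)?;
        let softmax_scale = 1.0 / (64f32).sqrt();
        let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, true)?;
        println!("{:?}", out.dims()); // same shape as q
        Ok(())
    }
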
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/src/lib.rs | 122
1 file changed, 81 insertions(+), 41 deletions(-)
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index cdb4b083..b610915b 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -4,7 +4,7 @@ use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
-use half::f16;
+use half::{bf16, f16};
pub struct FlashAttn {
pub softmax_scale: f32,
@@ -15,24 +15,10 @@ fn round_multiple(x: usize, m: usize) -> usize {
(x + m - 1) / m * m
}
-impl candle::CustomOp3 for FlashAttn {
- fn name(&self) -> &'static str {
- "flash-attn"
- }
-
- fn cpu_fwd(
- &self,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- ) -> Result<(CpuStorage, Shape)> {
- candle::bail!("no cpu support for flash-attn")
- }
-
- fn cuda_fwd(
+impl FlashAttn {
+ fn cuda_fwd_t<
+ T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+ >(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
@@ -46,9 +32,9 @@ impl candle::CustomOp3 for FlashAttn {
let out_shape = q_l.shape().clone();
let out_l = Layout::contiguous(&out_shape);
- let q = q.as_cuda_slice::<f16>()?;
- let k = k.as_cuda_slice::<f16>()?;
- let v = v.as_cuda_slice::<f16>()?;
+ let q = q.as_cuda_slice::<T>()?;
+ let k = k.as_cuda_slice::<T>()?;
+ let v = v.as_cuda_slice::<T>()?;
let q = q.slice(q_l.start_offset()..);
let k = k.slice(k_l.start_offset()..);
let v = v.slice(v_l.start_offset()..);
@@ -104,7 +90,7 @@ impl candle::CustomOp3 for FlashAttn {
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
- let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+ let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let causal = if self.causal { 1 } else { 0 };
@@ -155,6 +141,40 @@ impl candle::CustomOp3 for FlashAttn {
}
}
+impl candle::CustomOp3 for FlashAttn {
+ fn name(&self) -> &'static str {
+ "flash-attn"
+ }
+
+ fn cpu_fwd(
+ &self,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ candle::bail!("no cpu support for flash-attn")
+ }
+
+ fn cuda_fwd(
+ &self,
+ q: &candle::CudaStorage,
+ q_l: &Layout,
+ k: &candle::CudaStorage,
+ k_l: &Layout,
+ v: &candle::CudaStorage,
+ v_l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ match q.dtype() {
+ candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+ candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+ dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+ }
+ }
+}
+
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
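
(As an aside, not part of this diff: the formula in the doc-comment above can be checked against a naive implementation built from plain candle ops; a minimal sketch, assuming `candle_nn` is available for the softmax:)

    use candle::{Result, Tensor, D};

    // Naive scaled dot-product attention, softmax(Q @ K^T . softmax_scale) @ V.
    // q, k, v: (batch, num_heads, seq_len, head_dim), contiguous.
    fn attention_ref(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
        let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        att.matmul(v)
    }

(Flash-attn accumulates the softmax in f32 internally, so for F16/BF16 inputs the two paths agree only up to reduced-precision tolerance.)
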
@@ -191,24 +211,10 @@ struct FlashAttnVarLen {
seqlens_k: Tensor,
}
-impl candle::CustomOp3 for FlashAttnVarLen {
- fn name(&self) -> &'static str {
- "flash-attn-varlen"
- }
-
- fn cpu_fwd(
- &self,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- _: &CpuStorage,
- _: &Layout,
- ) -> Result<(CpuStorage, Shape)> {
- candle::bail!("no cpu support for flash-attn")
- }
-
- fn cuda_fwd(
+impl FlashAttnVarLen {
+ fn cuda_fwd_t<
+ T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
+ >(
&self,
q: &candle::CudaStorage,
q_l: &Layout,
@@ -364,6 +370,40 @@ impl candle::CustomOp3 for FlashAttnVarLen {
}
}
+impl candle::CustomOp3 for FlashAttnVarLen {
+ fn name(&self) -> &'static str {
+ "flash-attn-varlen"
+ }
+
+ fn cpu_fwd(
+ &self,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ _: &CpuStorage,
+ _: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ candle::bail!("no cpu support for flash-attn")
+ }
+
+ fn cuda_fwd(
+ &self,
+ q: &candle::CudaStorage,
+ q_l: &Layout,
+ k: &candle::CudaStorage,
+ k_l: &Layout,
+ v: &candle::CudaStorage,
+ v_l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ match q.dtype() {
+ candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
+ candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+ dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
+ }
+ }
+}
+
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///