author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-31 09:45:39 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2023-07-31 09:45:39 +0100
commit | 0ace420e66b86fd6146a02fe9b8aca6a41c0eabd (patch) |
tree | 76c842b732503b41f86b23b3b9aea98f959bc2b4 /candle-flash-attn |
parent | a8d8f9f20601b30124d1c5096e3ad276afc99bf8 (diff) |
Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 9
-rw-r--r-- | candle-flash-attn/src/ffi.rs | 2
-rw-r--r-- | candle-flash-attn/src/lib.rs | 231
-rw-r--r-- | candle-flash-attn/tests/flash_attn_tests.rs | 45
4 files changed, 283 insertions, 4 deletions
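
As the commit message above describes, the varlen path packs all the sequences of a batch back to back and indexes them through cumulative sequence lengths instead of padding. The following is a minimal usage sketch, not part of this commit: it assumes the `candle` and `candle-flash-attn` crates and a CUDA device, and the shapes, lengths, and values are invented for illustration. It only shows how the cumulative `seqlens` tensors relate to the packed `(total_tokens, num_heads, head_size)` layout expected by the new `flash_attn_varlen` entry point.

```rust
use candle::{DType, Device, Tensor};
use candle_flash_attn::flash_attn_varlen;

fn main() -> candle::Result<()> {
    let device = Device::new_cuda(0)?;

    // Two sequences of lengths 3 and 5, packed along dim 0 with no padding:
    // q/k/v have shape (total_tokens = 8, num_heads, head_size).
    let (num_heads, head_size) = (2, 32);
    let q = Tensor::randn(0f32, 1.0, (8, num_heads, head_size), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1.0, (8, num_heads, head_size), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1.0, (8, num_heads, head_size), &device)?.to_dtype(DType::F16)?;

    // Cumulative sequence lengths: batch_size + 1 entries starting at 0,
    // i.e. [0, 3, 3 + 5]. The same offsets are used for q, k and v here.
    let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
    let max_seqlen = 5; // length of the longest sequence in the batch

    let softmax_scale = 1f32 / (head_size as f32).sqrt();
    let out = flash_attn_varlen(
        &q, &k, &v, &seqlens, &seqlens, max_seqlen, max_seqlen, softmax_scale,
        /* causal */ true,
    )?;
    assert_eq!(out.dims(), &[8, num_heads, head_size]);
    Ok(())
}
```

The test added at the bottom of this diff exercises the same entry point on small fixed inputs.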
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 323aeaad..d928bcb6 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -22,6 +22,9 @@ extern "C" void run_mha(
     void *o_ptr,
     void *softmax_lse_ptr,
 
+    int32_t *cu_seqlens_q_ptr,
+    int32_t *cu_seqlens_k_ptr,
+
     uint32_t q_batch_stride,
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
@@ -100,9 +103,9 @@ extern "C" void run_mha(
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     params.is_bf16 = 0;
-    params.cu_seqlens_q = nullptr;
-    params.cu_seqlens_k = nullptr;
-    params.p_ptr = nullptr;
+    params.cu_seqlens_q = cu_seqlens_q_ptr;
+    params.cu_seqlens_k = cu_seqlens_k_ptr;
+    params.p_ptr = nullptr; // used for `return_softmax`.
 
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index f4415539..ae61c405 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -7,6 +7,8 @@ extern "C" {
         v_ptr: *const c_void,
         o_ptr: *const c_void,
         softmax_lse_ptr: *const c_void,
+        cu_seqlens_q_ptr: *const i32,
+        cu_seqlens_k_ptr: *const i32,
 
         q_batch_stride: u32,
         k_batch_stride: u32,
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index efdefee9..99b05229 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -49,6 +49,9 @@ impl candle::CustomOp3 for FlashAttn {
         let q = q.as_cuda_slice::<f16>()?;
         let k = k.as_cuda_slice::<f16>()?;
         let v = v.as_cuda_slice::<f16>()?;
+        let q = q.slice(q_l.start_offset()..);
+        let k = k.slice(k_l.start_offset()..);
+        let v = v.slice(v_l.start_offset()..);
 
         let q_stride = q_l.stride();
         let k_stride = k_l.stride();
@@ -118,6 +121,8 @@ impl candle::CustomOp3 for FlashAttn {
                 v_ptr,
                 dst_ptr,
                 softmax_lse_ptr,
+                /* cu_seqlens_q_ptr */ std::ptr::null(),
+                /* cu_seqlens_k_ptr */ std::ptr::null(),
                 /* q_batch_stride */ q_stride[0] as u32,
                 /* k_batch_stride */ k_stride[0] as u32,
                 /* v_batch_stride */ v_stride[0] as u32,
@@ -149,7 +154,7 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
-/// Flash-attention v2 layer using flash-attention.
+/// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
 /// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
@@ -175,3 +180,227 @@ pub fn flash_attn(
     };
     q.custom_op3(k, v, op)
 }
+
+struct FlashAttnVarLen {
+    softmax_scale: f32,
+    causal: bool,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    seqlens_q: Tensor,
+    seqlens_k: Tensor,
+}
+
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-hdim32-sm80"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
+        let dev = q.device();
+        let out_shape = q_l.shape().clone();
+        let out_l = Layout::contiguous(&out_shape);
+
+        let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
+        let seqlens_q = match &*seqlens_q {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
+            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+        };
+        let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
+            Some((o1, o2)) => seqlens_q.slice(o1..o2),
+            None => candle::bail!("seqlens_q has to be contiguous"),
+        };
+
+        let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
+        let seqlens_k = match &*seqlens_k {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
+            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+        };
+        let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
+            Some((o1, o2)) => seqlens_k.slice(o1..o2),
+            None => candle::bail!("seqlens_k has to be contiguous"),
+        };
+
+        let q = q.as_cuda_slice::<f16>()?;
+        let k = k.as_cuda_slice::<f16>()?;
+        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.slice(q_l.start_offset()..);
+        let k = k.slice(k_l.start_offset()..);
+        let v = v.slice(v_l.start_offset()..);
+
+        let q_stride = q_l.stride();
+        let k_stride = k_l.stride();
+        let v_stride = v_l.stride();
+        let o_stride = out_l.stride();
+
+        let q_rank = q_stride.len();
+        let k_rank = k_stride.len();
+        let v_rank = v_stride.len();
+        let o_rank = o_stride.len();
+
+        if q_rank != 3 || k_rank != 3 || v_rank != 3 {
+            candle::bail!(
+                "flash-attn-varlen expects input tensors of rank 3 (q: {q_rank}, k: {k_rank}, v: {v_rank}"
+            )
+        }
+        if q_stride[q_rank - 1] != 1 {
+            candle::bail!("the last dim of q must be contiguous {q_stride:?}")
+        }
+        if k_stride[k_rank - 1] != 1 {
+            candle::bail!("the last dim of k must be contiguous {k_stride:?}")
+        }
+        if v_stride[v_rank - 1] != 1 {
+            candle::bail!("the last dim of v must be contiguous {v_stride:?}")
+        }
+
+        let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+        let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
+        let expected_kv = (total_k, num_heads_k, head_size_og);
+        if expected_kv != k_l.shape().dims3()? {
+            candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
+        }
+        if expected_kv != v_l.shape().dims3()? {
+            candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape())
+        }
+        if head_size_og > 256 {
+            candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
+        }
+        if head_size_og % 8 != 0 {
+            // TODO: Handle head sizes that are not a multiple of 8 via some padding.
+            candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
+        }
+        if num_heads % num_heads_k != 0 {
+            candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
+        }
+
+        let nseqlens_q = seqlens_q_layout.shape().dims1()?;
+        if nseqlens_q < 2 {
+            candle::bail!("seqlens_q should have a len >= 2 {nseqlens_q}")
+        }
+        let nseqlens_k = seqlens_k_layout.shape().dims1()?;
+        if nseqlens_k != nseqlens_q {
+            candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
+        }
+        let batch_size = nseqlens_q - 1;
+        let head_size = round_multiple(head_size_og, 8);
+        let head_size_rounded = round_multiple(head_size, 32);
+        let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
+        let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
+
+        let elem_count = out_shape.elem_count();
+        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
+            .w()?;
+
+        let causal = if self.causal { 1 } else { 0 };
+
+        unsafe {
+            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
+            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
+            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
+            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
+            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
+            let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
+            let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
+            ffi::run_mha(
+                q_ptr,
+                k_ptr,
+                v_ptr,
+                dst_ptr,
+                softmax_lse_ptr,
+                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
+                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
+                /* q_batch_stride */ 0,
+                /* k_batch_stride */ 0,
+                /* v_batch_stride */ 0,
+                /* o_batch_stride */ 0,
+                /* q_row_stride */ q_stride[q_rank - 3] as u32,
+                /* k_row_stride */ k_stride[k_rank - 3] as u32,
+                /* v_row_stride */ v_stride[v_rank - 3] as u32,
+                /* o_row_stride */ o_stride[o_rank - 3] as u32,
+                /* q_head_stride */ q_stride[q_rank - 2] as u32,
+                /* k_head_stride */ k_stride[k_rank - 2] as u32,
+                /* v_head_stride */ v_stride[v_rank - 2] as u32,
+                /* o_head_stride */ o_stride[o_rank - 2] as u32,
+                /* b */ batch_size as u32,
+                /* h */ num_heads as u32,
+                /* h_k */ num_heads_k as u32,
+                /* d */ head_size as u32,
+                /* d_rounded */ head_size_rounded as u32,
+                /* softmax_scale*/ self.softmax_scale,
+                /* seqlen_q */ self.max_seqlen_q as u32,
+                /* seqlen_k */ self.max_seqlen_k as u32,
+                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
+                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
+                /* is_causal */ causal,
+            )
+        }
+
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
+        Ok((dst, out_shape))
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+pub fn flash_attn_varlen(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        causal,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+    };
+    q.custom_op3(k, v, op)
+}
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
index 43cb324f..250added 100644
--- a/candle-flash-attn/tests/flash_attn_tests.rs
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -88,3 +88,48 @@ fn flash_attn_acausal() -> Result<()> {
     assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
     Ok(())
 }
+
+#[test]
+fn flash_attn_varlen() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 48, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((3, 2, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+
+    let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?;
+    let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?;
+
+    let ys = {
+        let q = q.transpose(0, 1)?;
+        let k = k.transpose(0, 1)?;
+        let v = v.transpose(0, 1)?;
+        candle_flash_attn::flash_attn_varlen(
+            &q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false,
+        )?
+        .transpose(0, 1)?
+    };
+    let ys = ys.to_dtype(DType::F32)?;
+
+    assert_eq!(ys.dims(), &[3, 2, 8]);
+    assert_eq!(
+        to_vec3_round(ys, 4)?,
+        &[
+            [
+                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
+                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
+            ],
+            [
+                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
+                [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
+            ],
+            [
+                [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
+                [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
+            ]
+        ]
+    );
+    Ok(())
+}
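
The doc comment above defines the operation as `softmax(Q @ K^T . softmax_scale) @ V`, and the expected values hard-coded in the new test are presumably derived from a plain-ops computation of that formula (the existing `flash_attn_acausal` test compares against a reference in a similar way). The sketch below is not part of the commit; the helper name `attention_ref` is invented, and it assumes the `candle` and `candle_nn` crates as dependencies. Computing in f32 keeps the comparison against the f16 kernel output meaningful.

```rust
use candle::{DType, Tensor, D};

/// Plain-ops reference for `softmax(Q @ K^T * softmax_scale) @ V`, computed in f32.
/// Expects tensors shaped (num_heads, seq_len, head_size), i.e. one sequence at a time.
fn attention_ref(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f64) -> candle::Result<Tensor> {
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    // Batched matmul over the leading (num_heads) dimension.
    let att = (q.matmul(&k.t()?)? * softmax_scale)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    att.matmul(&v)
}
```

With the varlen layout, each sequence's rows would be narrowed out of the packed `(total_q, num_heads_q, head_size)` output using the cumulative offsets, transposed to `(num_heads, seq_len, head_size)`, and compared against this reference.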