summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-18 19:18:59 +0200
committerGitHub <noreply@github.com>2024-05-18 19:18:59 +0200
commit7ebc3548e19fb8d40940fa2315b3f817058e7e96 (patch)
treee8097869ef0b6f8db1a9091ef2f7df078e013cc7 /candle-flash-attn
parenteefc1c77ef00b74e1f8c6ac4e217dfbdbd419eff (diff)
downloadcandle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.gz
candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.bz2
candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.zip
Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma. * Fix flash-attn for head dim 256.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--candle-flash-attn/kernels/flash_fwd_launch_template.h4
-rw-r--r--candle-flash-attn/src/lib.rs4
2 files changed, 7 insertions, 1 deletions
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 66ab6206..002dd8ec 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+ if (smem_size >= 48 * 1024) {
+ cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ }
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 21a06b5e..f171a986 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
- let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+ let softmax_lse = dev
+ .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+ .w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };