From 7ebc3548e19fb8d40940fa2315b3f817058e7e96 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 18 May 2024 19:18:59 +0200 Subject: Use flash-attn in gemma. (#2195) * Use flash-attn in gemma. * Fix flash-attn for head dim 256. --- candle-flash-attn/kernels/flash_fwd_launch_template.h | 4 ++++ candle-flash-attn/src/lib.rs | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'candle-flash-attn') diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 66ab6206..002dd8ec 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 21a06b5e..f171a986 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -139,7 +139,9 @@ impl FlashAttn { let elem_count = out_shape.elem_count(); let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?; - let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?; + let softmax_lse = dev + .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q) + .w()?; let is_bf16 = if is_bf16 { 1 } else { 0 }; -- cgit v1.2.3