diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-18 19:18:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-18 19:18:59 +0200 |
commit | 7ebc3548e19fb8d40940fa2315b3f817058e7e96 (patch) | |
tree | e8097869ef0b6f8db1a9091ef2f7df078e013cc7 /candle-flash-attn/kernels/flash_fwd_launch_template.h | |
parent | eefc1c77ef00b74e1f8c6ac4e217dfbdbd419eff (diff) | |
download | candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.gz candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.bz2 candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.zip |
Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.
* Fix flash-attn for head dim 256.
Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_launch_template.h')
-rw-r--r-- | candle-flash-attn/kernels/flash_fwd_launch_template.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h index 66ab6206..002dd8ec 100644 --- a/candle-flash-attn/kernels/flash_fwd_launch_template.h +++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h @@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); |