diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-18 19:18:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-18 19:18:59 +0200 |
commit | 7ebc3548e19fb8d40940fa2315b3f817058e7e96 (patch) | |
tree | e8097869ef0b6f8db1a9091ef2f7df078e013cc7 /candle-flash-attn/src | |
parent | eefc1c77ef00b74e1f8c6ac4e217dfbdbd419eff (diff) | |
download | candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.gz candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.tar.bz2 candle-7ebc3548e19fb8d40940fa2315b3f817058e7e96.zip |
Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.
* Fix flash-attn for head dim 256.
Diffstat (limited to 'candle-flash-attn/src')
-rw-r--r-- | candle-flash-attn/src/lib.rs | 4 |
1 file changed, 3 insertions, 1 deletion
```diff
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 21a06b5e..f171a986 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -139,7 +139,9 @@ impl FlashAttn {
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+            .w()?;
         let is_bf16 = if is_bf16 { 1 } else { 0 };
```