summaryrefslogtreecommitdiff
path: root/candle-flash-attn/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r--candle-flash-attn/src/lib.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 21a06b5e..f171a986 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
- let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+ let softmax_lse = dev
+ .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+ .w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };