summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/model.rs6
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 0e850b6a..049d0c38 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -220,8 +220,12 @@ impl CausalSelfAttention {
let v = self.repeat_kv(v)?;
let y = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = q.transpose(1, 2)?;
+ let k = k.transpose(1, 2)?;
+ let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
- flash_attn(softmax_scale, &q, &k, &v)?
+ flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
} else {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;