 candle-transformers/src/models/llama.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a7bef099..e7769734 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -341,7 +341,8 @@ impl CausalSelfAttention {
                 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
                 masked_fill(&att, &mask, f32::NEG_INFINITY)?
             };
-            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
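
The change swaps the generic softmax over D::Minus1 for softmax_last_dim, candle-nn's specialized path for the common last-dimension case, presumably so the attention scores go through a single fused op rather than a composition of elementwise kernels. As a quick sanity check (a minimal standalone sketch, not part of this commit; it assumes candle-core and candle-nn as dependencies), the two calls should agree up to floating-point error:

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Attention-score shaped tensor: (batch, seq_len, seq_len).
    let att = Tensor::randn(0f32, 1f32, (2, 4, 4), &dev)?;
    // Generic softmax along the last dimension (the old call).
    let a = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Specialized last-dim variant (the new call).
    let b = candle_nn::ops::softmax_last_dim(&att)?;
    // The outputs should match up to floating-point error.
    let diff = (a - b)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-5);
    Ok(())
}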