author    Zack Angelo <zackangelo@gmail.com>  2024-10-23 11:07:09 -0700
committer GitHub <noreply@github.com>         2024-10-23 20:07:09 +0200
commit    a2e9d41b2062be5b45c84d24fe2bf4527ec27cee (patch)
tree      f587b1c8dc547d2213076adc653505df0f116711 /candle-transformers
parent    7c09215ef443256523d2de2579db56d1b59fd683 (diff)
Use softmax_last_dim (Metal and CUDA kernels) in the Llama attention layer (#2572)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/llama.rs  3
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a7bef099..e7769734 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -341,7 +341,8 @@ impl CausalSelfAttention {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
};
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+
+ let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
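
For context, the change swaps the generic candle_nn::ops::softmax(&att, D::Minus1) call for candle_nn::ops::softmax_last_dim(&att), which dispatches to dedicated Metal and CUDA kernels when the tensor lives on one of those devices. The sketch below is not part of the commit; it simply checks that the two calls agree numerically on an attention-shaped tensor. The tensor shape and the CPU device are illustrative assumptions (on CPU both calls use the same fallback implementation), and it assumes candle-core and candle-nn as dependencies.

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    // Illustrative device; swap in Device::new_cuda(0)? or Device::new_metal(0)?
    // to exercise the fused kernels this commit targets.
    let device = Device::Cpu;

    // Dummy attention scores shaped (batch * heads, seq_len, seq_len).
    let att = Tensor::randn(0f32, 1f32, (8, 16, 16), &device)?;

    // Old path: generic softmax over the last dimension.
    let reference = candle_nn::ops::softmax(&att, D::Minus1)?;
    // New path: softmax_last_dim, backed by dedicated Metal/CUDA kernels.
    let fused = candle_nn::ops::softmax_last_dim(&att)?;

    // The two results should match up to floating-point tolerance.
    let max_diff = (reference - fused)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    println!("max abs difference: {max_diff}");
    Ok(())
}

Run on a Metal or CUDA device, the same sketch doubles as a quick sanity check that the fused kernel path matches the reference softmax.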