author | Zack Angelo <zackangelo@gmail.com> | 2024-10-23 11:07:09 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-23 20:07:09 +0200 |
commit | a2e9d41b2062be5b45c84d24fe2bf4527ec27cee (patch) | |
tree | f587b1c8dc547d2213076adc653505df0f116711 /candle-transformers | |
parent | 7c09215ef443256523d2de2579db56d1b59fd683 (diff) | |
use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 3 |
1 file changed, 2 insertions(+), 1 deletion(-)
```diff
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a7bef099..e7769734 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -341,7 +341,8 @@ impl CausalSelfAttention {
                 let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
                 masked_fill(&att, &mask, f32::NEG_INFINITY)?
             };
-            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
```
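For context (not part of the commit): `candle_nn::ops::softmax_last_dim` computes the same result as `candle_nn::ops::softmax(&att, D::Minus1)`, but dispatches to dedicated Metal and CUDA kernels instead of the generic reduction path, which is the point of this change to the attention layer. Below is a minimal standalone sketch, with an arbitrary tensor shape chosen purely for illustration, that compares the two ops on CPU:

```rust
// Standalone sketch (not part of this commit): verify that softmax_last_dim
// matches softmax over D::Minus1. On Metal/CUDA backends softmax_last_dim
// additionally routes to a fused kernel; the numerical result is the same.
use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Shape mimicking attention scores (batch, heads, seq, seq) -- sizes are arbitrary.
    let att = Tensor::randn(0f32, 1f32, (1, 8, 16, 16), &device)?;

    // Generic softmax over the last dimension.
    let a = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Specialized last-dim softmax used by the patched llama attention layer.
    let b = candle_nn::ops::softmax_last_dim(&att)?;

    // Maximum absolute difference between the two implementations.
    let diff = (&a - &b)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
    println!("max |softmax - softmax_last_dim| = {diff:e}");
    Ok(())
}
```

On a CUDA or Metal device the outputs should likewise agree within floating-point tolerance; the difference is only in which kernel executes the softmax over the attention scores.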