author    Laurent Mazare <laurent.mazare@gmail.com>  2024-03-25 15:31:04 +0100
committer GitHub <noreply@github.com>  2024-03-25 15:31:04 +0100
commit    d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a (patch)
tree      06a083abb9e35c708a4a5736740858f46bfd45ad /candle-transformers/src/models/llama2_c.rs
parent    cd254074f354c4066bc73e1c5cc5ecc84d25a2db (diff)
Avoid the attention mask where possible. (#1933)
Diffstat (limited to 'candle-transformers/src/models/llama2_c.rs')
 candle-transformers/src/models/llama2_c.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 7b4f120b..bba8b666 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
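
Why this change helps: during incremental decoding with a KV cache, each forward pass typically processes a single new token (seq_len == 1), and a 1x1 causal mask masks nothing, so building, broadcasting, and applying it is wasted work. The sketch below shows the same optimization in isolation, assuming the candle-core and candle-nn crates; the masked_fill helper and the inline mask construction mirror what llama2_c.rs does via its cache, while attention_weights and the demo in main are hypothetical names introduced here for illustration.

use candle_core::{D, Device, Result, Tensor};

/// Mirrors the `masked_fill` helper in `llama2_c.rs`: wherever `mask` is 1,
/// replace the score with `on_true` (here, negative infinity).
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    mask.where_cond(&on_true, on_false)
}

/// Hypothetical attention step: apply the causal mask only when the query
/// spans more than one position, then softmax over the last dimension.
fn attention_weights(att: &Tensor, seq_len: usize, device: &Device) -> Result<Tensor> {
    let att = if seq_len <= 1 {
        // Single-token decode: the 1x1 causal mask is a no-op, skip it.
        att.clone()
    } else {
        // Upper-triangular mask: position i may not attend to positions j > i.
        let mask: Vec<u8> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?
            .broadcast_as(att.shape())?;
        masked_fill(&att, &mask, f32::NEG_INFINITY)?
    };
    candle_nn::ops::softmax(&att, D::Minus1)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Fake attention scores for batch=1, heads=1, seq_len=3.
    let att = Tensor::rand(0f32, 1f32, (1, 1, 3, 3), &device)?;
    let w = attention_weights(&att, 3, &device)?;
    println!("{w}");
    Ok(())
}

With seq_len = 3 the resulting weight rows are causal (weights above the diagonal are zero); with seq_len = 1 the mask branch is skipped entirely, which is the fast path this commit adds to the prompt-free decode steps.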