author    Laurent Mazare <laurent.mazare@gmail.com>  2024-03-25 15:31:04 +0100
committer GitHub <noreply@github.com>  2024-03-25 15:31:04 +0100
commit    d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a
tree      06a083abb9e35c708a4a5736740858f46bfd45ad /candle-transformers/src/models/llama2_c.rs
parent    cd254074f354c4066bc73e1c5cc5ecc84d25a2db
Avoid the attention mask where possible. (#1933)
Diffstat (limited to 'candle-transformers/src/models/llama2_c.rs')
 candle-transformers/src/models/llama2_c.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 7b4f120b..bba8b666 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
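
Why the mask can be skipped: during autoregressive decoding with a KV cache, each forward pass after the prompt handles a single new token, so seq_len is 1; a causal mask with only one query row hides nothing, making the broadcast_as and masked_fill pure overhead. Below is a minimal standalone sketch of that observation (plain Rust, no candle dependency; causal_mask is a hypothetical stand-in for what cache.mask(seq_len) builds, assuming the usual rule that key position j is hidden from query position i when j > i):

// Hypothetical stand-in for cache.mask(seq_len): entry (i, j) is 1 when
// key position j lies in the future of query position i and must be set
// to -inf before the softmax.
fn causal_mask(seq_len: usize) -> Vec<Vec<u8>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| u8::from(j > i)).collect())
        .collect()
}

fn main() {
    // Prompt processing: seq_len > 1, the mask has nonzero entries and
    // masking is required.
    assert_eq!(
        causal_mask(3),
        vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]],
    );
    // Token-by-token decoding: seq_len == 1, the single query position may
    // attend to everything already generated, so the mask is all zeros and
    // masked_fill would be a no-op -- exactly the case the diff short-circuits.
    assert_eq!(causal_mask(1), vec![vec![0]]);
}

Since the fill changes no values when seq_len <= 1, skipping it is numerically identical while saving a mask build, a broadcast, and a masked_fill per layer for every generated token.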