diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-28 18:26:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 17:26:33 +0100 |
commit | 23b3576c478ee46633da2b703c7961a6341f9d0f (patch) | |
tree | b5262de2ae992838bc64c395563ab970d9f4dd4a | |
parent | 716ab2ccdcb07aab26c41a98a839c31ac9760ca6 (diff) | |
download | candle-23b3576c478ee46633da2b703c7961a6341f9d0f.tar.gz candle-23b3576c478ee46633da2b703c7961a6341f9d0f.tar.bz2 candle-23b3576c478ee46633da2b703c7961a6341f9d0f.zip |
Add the sliding window. (#986)
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 33569bd8..245150e7 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -299,7 +299,6 @@ pub struct Model { layers: Vec<DecoderLayer>, norm: RmsNorm, lm_head: Linear, - #[allow(unused)] sliding_window: usize, device: Device, dtype: DType, @@ -338,7 +337,15 @@ impl Model { ) -> Result<Tensor> { // Sliding window mask? let mask: Vec<_> = (0..tgt_len) - .flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. })) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) .collect(); let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; let mask = if seqlen_offset > 0 { |