summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-28 18:26:33 +0200
committerGitHub <noreply@github.com>2023-09-28 17:26:33 +0100
commit23b3576c478ee46633da2b703c7961a6341f9d0f (patch)
treeb5262de2ae992838bc64c395563ab970d9f4dd4a
parent716ab2ccdcb07aab26c41a98a839c31ac9760ca6 (diff)
downloadcandle-23b3576c478ee46633da2b703c7961a6341f9d0f.tar.gz
candle-23b3576c478ee46633da2b703c7961a6341f9d0f.tar.bz2
candle-23b3576c478ee46633da2b703c7961a6341f9d0f.zip
Add the sliding window. (#986)
-rw-r--r--candle-transformers/src/models/mistral.rs11
1 files changed, 9 insertions, 2 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 33569bd8..245150e7 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -299,7 +299,6 @@ pub struct Model {
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
- #[allow(unused)]
sliding_window: usize,
device: Device,
dtype: DType,
@@ -338,7 +337,15 @@ impl Model {
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
- .flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + self.sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {