summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/marian.rs36
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index ebab3dbc..05804a1c 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -1,6 +1,5 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
-use candle::{Module, Result, Tensor};
+use super::with_tracing::{linear, Embedding, Linear};
+use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
@@ -170,7 +169,6 @@ impl Attention {
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
- let is_cross_attn = kv_states.is_some();
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
@@ -259,6 +257,10 @@ impl EncoderLayer {
.apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm)
}
+
+ fn reset_kv_cache(&mut self) {
+ self.self_attn.reset_kv_cache()
+ }
}
#[derive(Debug, Clone)]
@@ -320,6 +322,11 @@ impl DecoderLayer {
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs)
}
+
+ fn reset_kv_cache(&mut self) {
+ self.self_attn.reset_kv_cache();
+ self.encoder_attn.reset_kv_cache()
+ }
}
#[derive(Debug, Clone)]
@@ -368,6 +375,12 @@ impl Encoder {
}
Ok(xs)
}
+
+ pub fn reset_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.reset_kv_cache()
+ }
+ }
}
#[derive(Debug, Clone)]
@@ -422,6 +435,12 @@ impl Decoder {
}
Ok(xs)
}
+
+ pub fn reset_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.reset_kv_cache()
+ }
+ }
}
#[derive(Debug, Clone)]
@@ -442,6 +461,11 @@ impl Model {
decoder,
})
}
+
+ fn reset_kv_cache(&mut self) {
+ self.encoder.reset_kv_cache();
+ self.decoder.reset_kv_cache();
+ }
}
#[derive(Debug, Clone)]
@@ -489,4 +513,8 @@ impl MTModel {
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
+
+ pub fn reset_kv_cache(&mut self) {
+ self.model.reset_kv_cache();
+ }
}