author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-01 21:04:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-01 20:04:52 +0000 |
commit | 6c990a33ea4635bf98b180f6e4c99e6795ccfbab (patch) | |
tree | 92e42aab766bb7dbb4bc99d6b37409bbe33db30b | |
parent | 1704f1b3aec92b07dd805411fa8065eab55e4186 (diff) | |
download | candle-6c990a33ea4635bf98b180f6e4c99e6795ccfbab.tar.gz candle-6c990a33ea4635bf98b180f6e4c99e6795ccfbab.tar.bz2 candle-6c990a33ea4635bf98b180f6e4c99e6795ccfbab.zip |
Remove the unused pragma for marian. (#1236)
-rw-r--r-- | candle-transformers/src/models/marian.rs | 36 |
1 file changed, 32 insertions, 4 deletions
```diff
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index ebab3dbc..05804a1c 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -1,6 +1,5 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
-use candle::{Module, Result, Tensor};
+use super::with_tracing::{linear, Embedding, Linear};
+use candle::{Result, Tensor};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 #[derive(Debug, Clone)]
@@ -170,7 +169,6 @@ impl Attention {
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
-        let is_cross_attn = kv_states.is_some();
         let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
         let (key_states, value_states) = match kv_states {
@@ -259,6 +257,10 @@ impl EncoderLayer {
             .apply(&self.fc2)?;
         (xs + residual)?.apply(&self.final_layer_norm)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attn.reset_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -320,6 +322,11 @@ impl DecoderLayer {
         let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
         Ok(xs)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_attn.reset_kv_cache();
+        self.encoder_attn.reset_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -368,6 +375,12 @@ impl Encoder {
         }
         Ok(xs)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.reset_kv_cache()
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -422,6 +435,12 @@ impl Decoder {
         }
         Ok(xs)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.reset_kv_cache()
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -442,6 +461,11 @@ impl Model {
             decoder,
         })
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.encoder.reset_kv_cache();
+        self.decoder.reset_kv_cache();
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -489,4 +513,8 @@ impl MTModel {
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.model.reset_kv_cache();
+    }
 }
```
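The first hunk is the commit's namesake: dropping the file-wide `#![allow(unused)]` re-enables the `unused` lint group for the module, which immediately flags the `linear_no_bias` and `Module` imports and the dead `is_cross_attn` binding, so the patch deletes them in the same pass. An illustrative sketch (a hypothetical file, not from the repository) of what such an attribute masks:

```rust
// Illustrative only: a file-wide inner attribute like the one removed
// here silences every `unused` lint in the module, which is how dead
// imports and bindings accumulate unnoticed.
#![allow(unused)]

use std::collections::HashMap; // never used, but rustc stays silent

fn main() {
    let is_cross_attn = true; // dead binding, also silenced
    println!("compiles without a single warning");
}
```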
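The remaining hunks thread a `reset_kv_cache` method from the attention blocks up through `EncoderLayer`, `DecoderLayer`, `Encoder`, `Decoder`, `Model`, and the public `MTModel`, so a caller can clear the cached keys/values between unrelated inputs instead of rebuilding the model. A minimal usage sketch, where `decode_one` is a hypothetical closure standing in for the application's own greedy or beam decoding loop (only `reset_kv_cache` itself comes from this commit):

```rust
use candle::{Result, Tensor};
use candle_transformers::models::marian::MTModel;

/// Translate several independent inputs with a single model instance.
fn translate_all(
    model: &mut MTModel,
    sources: &[Tensor],
    decode_one: impl Fn(&mut MTModel, &Tensor) -> Result<Tensor>,
) -> Result<Vec<Tensor>> {
    let mut outputs = Vec::with_capacity(sources.len());
    for src in sources {
        // Drop the keys/values cached during the previous sequence so the
        // self- and cross-attention caches start empty for this one.
        model.reset_kv_cache();
        outputs.push(decode_one(model, src)?);
    }
    Ok(outputs)
}
```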