summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/mixformer.rs12
-rw-r--r--candle-transformers/src/models/quantized_mixformer.rs12
2 files changed, 24 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index e945cd51..b2fa2860 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -287,6 +287,10 @@ impl MHA {
.flatten_from(D::Minus2)?;
attn_output.apply(&self.out_proj)
}
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
}
#[derive(Debug)]
@@ -318,6 +322,10 @@ impl ParallelBlock {
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
+
+ fn clear_kv_cache(&mut self) {
+ self.mixer.clear_kv_cache()
+ }
}
#[derive(Debug)]
@@ -360,4 +368,8 @@ impl MixFormerSequentialForCausalLM {
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
+
+ pub fn clear_kv_cache(&mut self) {
+ self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
}
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index 4ace2045..e458cf5c 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -268,6 +268,10 @@ impl MHA {
.flatten_from(D::Minus2)?;
attn_output.apply(&self.out_proj)
}
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
}
#[derive(Debug)]
@@ -299,6 +303,10 @@ impl ParallelBlock {
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
+
+ fn clear_kv_cache(&mut self) {
+ self.mixer.clear_kv_cache()
+ }
}
#[derive(Debug)]
@@ -341,4 +349,8 @@ impl MixFormerSequentialForCausalLM {
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
+
+ pub fn clear_kv_cache(&mut self) {
+ self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
}