-rw-r--r--  candle-transformers/src/models/t5.rs | 30
1 file changed, 30 insertions, 0 deletions
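The patch below threads a `clear_kv_cache` method through every level of the T5 implementation: the attention layer itself (`T5Attention`), the self- and cross-attention wrappers, `T5Block`, `T5Stack`, and the public `T5EncoderModel` and `T5ForConditionalGeneration` types, so callers can reset the key/value cache between independent generations.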
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 2b71fcda..94cf5233 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -495,6 +495,10 @@ impl T5Attention {
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug)]
@@ -530,6 +534,10 @@ impl T5LayerSelfAttention {
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -568,6 +576,10 @@ impl T5LayerCrossAttention {
         let ys = (hidden_states + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.cross_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -634,6 +646,11 @@ impl T5Block {
         // TODO: clamp for f16?
         Ok((xs, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+    }
 }
 
 #[derive(Debug)]
@@ -680,6 +697,10 @@ impl T5Stack {
         }
         self.final_layer_norm.forward(&hidden_states)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+    }
 }
 
 #[derive(Debug)]
@@ -709,6 +730,10 @@ impl T5EncoderModel {
     pub fn device(&self) -> &Device {
         &self.device
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -808,4 +833,9 @@ impl T5ForConditionalGeneration {
     pub fn device(&self) -> &Device {
         &self.device
    }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache();
+        self.decoder.clear_kv_cache();
+    }
 }
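
A minimal usage sketch, not part of the patch: the new method resets the key/value caches so a second, independent generation does not reuse entries cached during the first. The `encode`/`decode` calls below are assumed entry points of `T5ForConditionalGeneration`; the crate is published as `candle-core` (aliased to `candle` inside the candle workspace).

use candle_core::{Result, Tensor};
use candle_transformers::models::t5::T5ForConditionalGeneration;

// `input_a` / `input_b`: tokenized encoder inputs of shape (1, seq_len);
// `decoder_start`: a (1, 1) tensor holding the decoder start token id.
fn run_two_prompts(
    model: &mut T5ForConditionalGeneration,
    input_a: &Tensor,
    input_b: &Tensor,
    decoder_start: &Tensor,
) -> Result<()> {
    let enc_a = model.encode(input_a)?;
    let _logits_a = model.decode(decoder_start, &enc_a)?; // populates the KV caches
    // Without this reset, decoding the second prompt would attend over
    // keys/values cached while processing the first one.
    model.clear_kv_cache();
    let enc_b = model.encode(input_b)?;
    let _logits_b = model.decode(decoder_start, &enc_b)?;
    Ok(())
}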