diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-21 09:01:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-21 09:01:06 +0100 |
commit | c89b82b2d419bd2e99ffc64c90a2615e97d4ea66 (patch) | |
tree | 88c7de4869279be54926addbaeb3627e7210aefd | |
parent | 7b26e513f15a0c7cd55ccfe48525bda1079427ce (diff) | |
download | candle-c89b82b2d419bd2e99ffc64c90a2615e97d4ea66.tar.gz candle-c89b82b2d419bd2e99ffc64c90a2615e97d4ea66.tar.bz2 candle-c89b82b2d419bd2e99ffc64c90a2615e97d4ea66.zip |
Add a clear cache function to the t5 model. (#919)
-rw-r--r-- | candle-transformers/src/models/t5.rs | 30 |
1 file changed, 30 insertions, 0 deletions
```diff
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 2b71fcda..94cf5233 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -495,6 +495,10 @@ impl T5Attention {
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug)]
@@ -530,6 +534,10 @@ impl T5LayerSelfAttention {
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -568,6 +576,10 @@ impl T5LayerCrossAttention {
         let ys = (hidden_states + ys)?;
         Ok((ys, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.cross_attention.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -634,6 +646,11 @@ impl T5Block {
         // TODO: clamp for f16?
         Ok((xs, position_bias))
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+    }
 }
 
 #[derive(Debug)]
@@ -680,6 +697,10 @@ impl T5Stack {
         }
         self.final_layer_norm.forward(&hidden_states)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+    }
 }
 
 #[derive(Debug)]
@@ -709,6 +730,10 @@ impl T5EncoderModel {
     pub fn device(&self) -> &Device {
         &self.device
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache()
+    }
 }
 
 #[derive(Debug)]
@@ -808,4 +833,9 @@ impl T5ForConditionalGeneration {
     pub fn device(&self) -> &Device {
         &self.device
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.encoder.clear_kv_cache();
+        self.decoder.clear_kv_cache();
+    }
 }
```