Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/mistral.rs            | 34
-rw-r--r--  candle-transformers/src/models/quantized_mistral.rs  | 14
2 files changed, 48 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index caf96bce..2a66515b 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -21,6 +21,7 @@ pub struct Config {
 }
 
 impl Config {
+    // https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
     pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 32000,
@@ -37,6 +38,25 @@ impl Config {
             use_flash_attn,
         }
     }
+
+    // https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json
+    // https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
+    pub fn config_chat_ml(use_flash_attn: bool) -> Self {
+        Self {
+            vocab_size: 32002,
+            hidden_size: 4096,
+            intermediate_size: 14336,
+            num_hidden_layers: 32,
+            num_attention_heads: 32,
+            num_key_value_heads: 8,
+            hidden_act: Activation::Silu,
+            max_position_embeddings: 32768,
+            rms_norm_eps: 1e-5,
+            rope_theta: 10_000.,
+            sliding_window: 4096,
+            use_flash_attn,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -277,6 +297,10 @@ impl Attention {
             .reshape((b_sz, q_len, self.hidden_size))?
             .apply(&self.o_proj)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -320,6 +344,10 @@ impl DecoderLayer {
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
         residual + xs
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -403,4 +431,10 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.clear_kv_cache()
+        }
+    }
 }
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 9e306c67..f2cb3b27 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -198,6 +198,10 @@ impl Attention {
             .reshape((b_sz, q_len, self.hidden_size))?
             .apply(&self.o_proj)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -241,6 +245,10 @@ impl DecoderLayer {
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
         residual + xs
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -322,4 +330,10 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.clear_kv_cache()
+        }
+    }
 }
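Note: the new Config::config_chat_ml constructor targets ChatML-style Mistral fine-tunes (Open-Orca and OpenHermes 2.5), whose tokenizer adds the <|im_start|>/<|im_end|> tokens, hence vocab_size 32002. A minimal sketch of how the config might be used to build the model; the weight-loading step and crate aliases are assumptions and not part of this change, only Config::config_chat_ml and Model::new come from this crate:

// Sketch only: assumes the crates are used as candle_core / candle_nn /
// candle_transformers (inside the workspace candle_core is aliased as candle).
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::mistral::{Config, Model};

fn load_chat_ml_model(weights: &[std::path::PathBuf], device: &Device) -> Result<Model> {
    // ChatML fine-tunes add two special tokens, so vocab_size is 32002.
    let config = Config::config_chat_ml(false /* use_flash_attn */);
    // Assumed loading step: memory-map safetensors shards into a VarBuilder.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weights, DType::F32, device)? };
    Model::new(&config, vb)
}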
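The clear_kv_cache methods reset each layer's attention KV cache so that one loaded model instance can be reused for a fresh, unrelated prompt without attending to stale keys and values from the previous sequence. A rough usage sketch, with the sampling loop omitted and the surrounding driver code assumed:

use candle_core::{Device, Result, Tensor};
use candle_transformers::models::mistral::Model;

// Run several independent prompts through a single model instance.
fn run_prompts(model: &mut Model, prompts: &[Vec<u32>], device: &Device) -> Result<()> {
    for tokens in prompts {
        // Drop the cached keys/values left over from the previous prompt.
        model.clear_kv_cache();
        // Shape (1, seq_len); seqlen_offset is 0 for a fresh sequence.
        let input = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
        let _logits = model.forward(&input, 0)?;
        // ... sample follow-up tokens here, passing the running seqlen_offset ...
    }
    Ok(())
}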