Diffstat (limited to 'candle-transformers/src')
 -rw-r--r--  candle-transformers/src/models/mistral.rs           | 34 +
 -rw-r--r--  candle-transformers/src/models/quantized_mistral.rs | 14 +
 2 files changed, 48 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index caf96bce..2a66515b 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -21,6 +21,7 @@ pub struct Config {
 }
 
 impl Config {
+    // https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
     pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
         Self {
             vocab_size: 32000,
@@ -37,6 +38,25 @@ impl Config {
             use_flash_attn,
         }
     }
+
+    // https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json
+    // https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
+    pub fn config_chat_ml(use_flash_attn: bool) -> Self {
+        Self {
+            vocab_size: 32002,
+            hidden_size: 4096,
+            intermediate_size: 14336,
+            num_hidden_layers: 32,
+            num_attention_heads: 32,
+            num_key_value_heads: 8,
+            hidden_act: Activation::Silu,
+            max_position_embeddings: 32768,
+            rms_norm_eps: 1e-5,
+            rope_theta: 10_000.,
+            sliding_window: 4096,
+            use_flash_attn,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -277,6 +297,10 @@ impl Attention {
             .reshape((b_sz, q_len, self.hidden_size))?
             .apply(&self.o_proj)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -320,6 +344,10 @@ impl DecoderLayer {
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
         residual + xs
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -403,4 +431,10 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.clear_kv_cache()
+        }
+    }
 }
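
Taken together, the mistral.rs changes add a ChatML-flavored configuration and a way to reset the per-layer KV cache between unrelated prompts. A minimal usage sketch, assuming the usual candle-core/candle-nn/candle-transformers dependencies and an already-initialized VarBuilder `vb` over the model weights; the function name `two_prompts` and the token ids are placeholders, not part of this diff:

use candle_core::{Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mistral::{Config, Model};

fn two_prompts(vb: VarBuilder) -> Result<()> {
    // ChatML-style config (OpenOrca / OpenHermes checkpoints): note the
    // enlarged vocab (32002) for the added <|im_start|>/<|im_end|> tokens.
    let cfg = Config::config_chat_ml(false);
    let mut model = Model::new(&cfg, vb)?;

    // First prompt, sequence offset 0 (placeholder token ids).
    let prompt_a = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?.unsqueeze(0)?;
    let _logits = model.forward(&prompt_a, 0)?;

    // Drop the cached keys/values before an unrelated prompt; otherwise the
    // second forward pass would attend to prompt_a's stale cache entries.
    model.clear_kv_cache();
    let prompt_b = Tensor::new(&[4u32, 5, 6], &Device::Cpu)?.unsqueeze(0)?;
    let _logits = model.forward(&prompt_b, 0)?;
    Ok(())
}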
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 9e306c67..f2cb3b27 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -198,6 +198,10 @@ impl Attention {
             .reshape((b_sz, q_len, self.hidden_size))?
             .apply(&self.o_proj)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -241,6 +245,10 @@ impl DecoderLayer {
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
         residual + xs
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache()
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -322,4 +330,10 @@ impl Model {
             .apply(&self.norm)?
             .apply(&self.lm_head)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.clear_kv_cache()
+        }
+    }
 }
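
The quantized model gains the same reset hook, so one instance can be reused across prompts regardless of backend. A sketch under the same assumptions as above, with model construction from quantized weights elided; `reuse_model` and the token ids are placeholders:

use candle_core::{Device, Result, Tensor};
use candle_transformers::models::quantized_mistral::Model;

fn reuse_model(model: &mut Model) -> Result<()> {
    let dev = Device::Cpu;
    let prompt_a = Tensor::new(&[1u32, 2, 3], &dev)?.unsqueeze(0)?;
    let _ = model.forward(&prompt_a, 0)?;

    // Same contract as the non-quantized model: clear_kv_cache walks the
    // decoder layers and drops each attention block's cached (k, v) pair.
    model.clear_kv_cache();

    let prompt_b = Tensor::new(&[4u32, 5], &dev)?.unsqueeze(0)?;
    let _ = model.forward(&prompt_b, 0)?;
    Ok(())
}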