diff options
Diffstat (limited to 'candle-transformers/src/models/mistral.rs')
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 38 |
1 file changed, 34 insertions, 4 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index 7e3b21c9..e8f7a7c4 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; +fn default_num_attention_heads() -> usize { + 32 +} + fn default_use_flash_attn() -> bool { false } +fn default_hidden_act() -> candle_nn::Activation { + candle_nn::Activation::Silu +} + #[derive(Debug, Clone, PartialEq, serde::Deserialize)] pub struct Config { pub vocab_size: usize, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, + #[serde(default = "default_num_attention_heads")] pub num_attention_heads: usize, pub head_dim: Option<usize>, pub num_key_value_heads: usize, + #[serde(default = "default_hidden_act")] pub hidden_act: Activation, pub max_position_embeddings: usize, pub rms_norm_eps: f64, @@ -107,14 +117,14 @@ impl RotaryEmbedding { .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) .collect(); let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? + .to_dtype(DType::F32)? 
.reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, + sin: freqs.sin()?.to_dtype(dtype)?, + cos: freqs.cos()?.to_dtype(dtype)?, }) } @@ -404,6 +414,10 @@ impl Model { .to_dtype(self.dtype) } + pub fn embed_tokens(&self) -> &candle_nn::Embedding { + &self.embed_tokens + } + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { let (_b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { @@ -421,6 +435,22 @@ impl Model { .apply(&self.lm_head) } + pub fn forward_embeds( + &mut self, + xs: &Tensor, + attn_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (_b_size, seq_len, _) = xs.dims3()?; + let mut xs = xs.clone(); + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attn_mask, seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } + pub fn clear_kv_cache(&mut self) { for layer in self.layers.iter_mut() { layer.clear_kv_cache() |