summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/mistral.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/mistral.rs')
-rw-r--r--  candle-transformers/src/models/mistral.rs  38
1 file changed, 34 insertions(+), 4 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 7e3b21c9..e8f7a7c4 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
+fn default_num_attention_heads() -> usize {
+ 32
+}
+
fn default_use_flash_attn() -> bool {
false
}
+fn default_hidden_act() -> candle_nn::Activation {
+ candle_nn::Activation::Silu
+}
+
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
+ #[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
pub head_dim: Option<usize>,
pub num_key_value_heads: usize,
+ #[serde(default = "default_hidden_act")]
pub hidden_act: Activation,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
@@ -107,14 +117,14 @@ impl RotaryEmbedding {
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
- let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
- .to_dtype(dtype)?
+ .to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
- sin: freqs.sin()?,
- cos: freqs.cos()?,
+ sin: freqs.sin()?.to_dtype(dtype)?,
+ cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@@ -404,6 +414,10 @@ impl Model {
.to_dtype(self.dtype)
}
+ pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+ &self.embed_tokens
+ }
+
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
@@ -421,6 +435,22 @@ impl Model {
.apply(&self.lm_head)
}
+ pub fn forward_embeds(
+ &mut self,
+ xs: &Tensor,
+ attn_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (_b_size, seq_len, _) = xs.dims3()?;
+ let mut xs = xs.clone();
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()