Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- candle-transformers/src/models/mistral.rs | 25
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 4ce89081..7db83ff1 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -1,7 +1,6 @@
-#![allow(unused)]
-use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear};
/// Mistral LLM, https://github.com/mistralai/mistral-src
-use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
@@ -99,7 +98,7 @@ impl RotaryEmbedding {
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
- let (b_sz, seq_len, h, n_embd) = q.dims4()?;
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
@@ -240,7 +239,7 @@ impl Attention {
let attn_weights = match attention_mask {
None => attn_weights,
- Some(mask) => (attn_weights + mask)?,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&value_states)?;
@@ -290,7 +289,7 @@ impl DecoderLayer {
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
- Ok(xs)
+ residual + xs
}
}
@@ -300,22 +299,24 @@ pub struct Model {
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
+ #[allow(unused)]
sliding_window: usize,
device: Device,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("model");
let embed_tokens =
- candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
- let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb.device())?);
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
- let vb_l = vb.pp("layers");
+ let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
- let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
@@ -359,6 +360,8 @@ impl Model {
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
- xs.apply(&self.norm)?.apply(&self.lm_head)
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
}
}
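
A minimal sketch of the last-position narrowing that the updated forward pass applies before the lm_head projection (assuming the `candle` alias for candle-core used by this crate; shapes are illustrative only):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative hidden states with shape (batch, seq_len, hidden) = (1, 5, 8).
    let xs = Tensor::randn(0f32, 1f32, (1, 5, 8), &dev)?;
    let (_b_sz, seq_len, _hidden) = xs.dims3()?;
    // Keep only the final position so the subsequent projection produces
    // logits for a single token rather than for the whole sequence.
    let last = xs.narrow(1, seq_len - 1, 1)?;
    assert_eq!(last.dims(), &[1, 1, 8]);
    Ok(())
}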