author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-28 17:19:18 +0200
---|---|---
committer | GitHub <noreply@github.com> | 2023-09-28 16:19:18 +0100
commit | ada8851a23bcddccbc5bc6e972b020401a1fe389 (patch) |
tree | ae609e2ec1655d63e4145073af0b3285944f14f4 /candle-transformers/src |
parent | c05a348e3626aa1b609d767317df600e0a838ae4 (diff) |
Add the mistral example. (#984)
* Add the mistral example.
* Use the two model files.
* Adjust the dtype.
* Tweak the weight paths (see the VarBuilder sketch after this list).
* Remove the end of text token.
* Get the mistral model to generate some text.
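The weight-path tweak matches the loader to the checkpoint layout: every transformer weight in the Mistral safetensors files lives under a `model.` prefix, while `lm_head.weight` sits at the top level. Below is a minimal sketch of how `VarBuilder::pp` composes those prefixes; the zeros-backed builder and the Mistral-7B-style shapes (32000 vocab, 4096 hidden, 14336 intermediate) are assumptions for illustration, not part of this commit.

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Zeros-backed builder: returns zero tensors for any requested path,
    // which is enough to show how `pp` nests prefixes. The real example
    // loads the safetensors checkpoint instead.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);

    // Everything inside the transformer is scoped under "model.".
    let vb_m = vb.pp("model");
    // Resolves the checkpoint key "model.embed_tokens.weight".
    let _embed = vb_m.pp("embed_tokens").get((32000, 4096), "weight")?;
    // Per-layer weights nest further: "model.layers.0.mlp.gate_proj.weight".
    let _gate = vb_m
        .pp("layers")
        .pp(0)
        .pp("mlp")
        .pp("gate_proj")
        .get((14336, 4096), "weight")?;
    // The LM head is *not* under "model.", so it comes from the root builder.
    let _lm_head = vb.pp("lm_head").get((32000, 4096), "weight")?;
    Ok(())
}
```

This is why `Model::new` in the diff below introduces `vb_m = vb.pp("model")` for everything except `lm_head`.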
Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 25 |
1 file changed, 14 insertions, 11 deletions
```diff
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 4ce89081..7db83ff1 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -1,7 +1,6 @@
-#![allow(unused)]
-use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear};
 /// Mistral LLM, https://github.com/mistralai/mistral-src
-use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
@@ -99,7 +98,7 @@ impl RotaryEmbedding {
         k: &Tensor,
         seqlen_offset: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (b_sz, seq_len, h, n_embd) = q.dims4()?;
+        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
         let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
@@ -240,7 +239,7 @@ impl Attention {
 
         let attn_weights = match attention_mask {
             None => attn_weights,
-            Some(mask) => (attn_weights + mask)?,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights.matmul(&value_states)?;
@@ -290,7 +289,7 @@ impl DecoderLayer {
         let xs = (xs + residual)?;
         let residual = &xs;
         let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
-        Ok(xs)
+        residual + xs
     }
 }
 
@@ -300,22 +299,24 @@ pub struct Model {
     layers: Vec<DecoderLayer>,
     norm: RmsNorm,
     lm_head: Linear,
+    #[allow(unused)]
     sliding_window: usize,
     device: Device,
 }
 
 impl Model {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("model");
         let embed_tokens =
-            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
-        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb.device())?);
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
-        let vb_l = vb.pp("layers");
+        let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
@@ -359,6 +360,8 @@ impl Model {
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
         }
-        xs.apply(&self.norm)?.apply(&self.lm_head)
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
     }
 }
```
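A few of the hunks deserve a note. The `dims4` destructuring in `RotaryEmbedding` was reordered because `q` arrives transposed as `(b, heads, seq_len, head_dim)`. The mask addition switched to `broadcast_add` because the attention weights are `(b, heads, q_len, kv_len)` while the causal mask is built as `(1, 1, q_len, kv_len)`, and candle's plain `+` does not broadcast. Finally, narrowing to the last position before `lm_head` shrinks the vocab-sized projection from `seq_len` rows to one. A standalone sketch of the last two points, with toy shapes that are illustrative rather than the model's:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // broadcast_add: (b, heads, q, k) + (1, 1, q, k). Plain `+` requires
    // identical shapes in candle, so the mask is added with broadcasting
    // over the leading batch and head dimensions.
    let attn_weights = Tensor::zeros((2, 8, 4, 4), DType::F32, &dev)?;
    let mask = Tensor::zeros((1, 1, 4, 4), DType::F32, &dev)?;
    let masked = attn_weights.broadcast_add(&mask)?;
    assert_eq!(masked.dims(), &[2, 8, 4, 4]);

    // Last-token narrowing: only the final position's logits are sampled
    // during generation, so the (hidden -> vocab) projection can run on a
    // single position instead of all seq_len positions.
    let (b_sz, seq_len, hidden) = (2usize, 5usize, 16usize);
    let xs = Tensor::randn(0f32, 1f32, (b_sz, seq_len, hidden), &dev)?;
    let last = xs.narrow(1, seq_len - 1, 1)?;
    assert_eq!(last.dims(), &[2, 1, 16]);
    Ok(())
}
```

In the model itself the narrow happens inside `Model::forward`, where `seq_len` comes from the input ids, so each decoding step only pays for one row of logits.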