summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/tensor-tools.rs36
-rw-r--r--candle-examples/examples/mistral/README.md50
-rw-r--r--candle-examples/examples/mistral/main.rs53
-rw-r--r--candle-examples/src/token_output_stream.rs16
-rw-r--r--candle-transformers/src/models/mistral.rs24
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/quantized_mistral.rs364
7 files changed, 507 insertions, 37 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 3982f2c3..d06b30d1 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -103,8 +103,10 @@ enum Command {
Quantize {
/// The input file, in gguf format.
- in_file: std::path::PathBuf,
+ in_file: Vec<std::path::PathBuf>,
+
/// The output file, in gguf format.
+ #[arg(long)]
out_file: std::path::PathBuf,
/// The quantization schema to apply.
@@ -218,12 +220,16 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
fn run_quantize_safetensors(
- in_file: std::path::PathBuf,
+ in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
) -> Result<()> {
let mut out_file = std::fs::File::create(out_file)?;
- let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+ let mut tensors = std::collections::HashMap::new();
+ for in_file in in_files.iter() {
+ let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+ tensors.extend(in_tensors)
+ }
println!("tensors: {}", tensors.len());
let quantize_fn = match q {
@@ -280,20 +286,32 @@ fn run_quantize_safetensors(
}
fn run_quantize(
- in_file: std::path::PathBuf,
+ in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
- if let Some(extension) = in_file.extension() {
+ if in_files.is_empty() {
+ candle_core::bail!("no specified input files")
+ }
+ if let Some(extension) = out_file.extension() {
+ if extension == "safetensors" {
+ candle_core::bail!("the generated file cannot use the safetensors extension")
+ }
+ }
+ if let Some(extension) = in_files[0].extension() {
if extension == "safetensors" {
- return run_quantize_safetensors(in_file, out_file, q);
+ return run_quantize_safetensors(in_files, out_file, q);
}
}
+ if in_files.len() != 1 {
+ candle_core::bail!("only a single in-file can be used when quantizing gguf files")
+ }
+
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
- let mut in_ = std::fs::File::open(&in_file)?;
+ let mut in_ = std::fs::File::open(&in_files[0])?;
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
@@ -319,7 +337,7 @@ fn run_quantize(
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
- let mut in_file = std::fs::File::open(&in_file)?;
+ let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor))
@@ -360,7 +378,7 @@ fn main() -> anyhow::Result<()> {
out_file,
quantization,
mode,
- } => run_quantize(in_file, out_file, quantization, mode)?,
+ } => run_quantize(&in_file, out_file, quantization, mode)?,
}
Ok(())
}
diff --git a/candle-examples/examples/mistral/README.md b/candle-examples/examples/mistral/README.md
index 6a5a0424..61a6666e 100644
--- a/candle-examples/examples/mistral/README.md
+++ b/candle-examples/examples/mistral/README.md
@@ -6,6 +6,9 @@ as of 2023-09-28. Weights (and the original Python model code) are released unde
- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
HuggingFace Hub.
+This example supports the initial model as well as a quantized variant.
+
+## Running the example
```bash
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
@@ -38,3 +41,50 @@ fn main() {
This example is released under the terms
```
+
+## Running the quantized version of the model
+
+```bash
+$ cargo run --example mistral --features accelerate --release -- \
+$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
+avx: false, neon: true, simd128: false, f16c: false
+temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
+retrieved the files in 562.292µs
+loaded the model in 1.100323667s
+Here is a sample quick sort implementation in rust
+
+``rust
+fn quick_sort(arr: &mut [i32]) {
+ if arr.len() <= 1 {
+ return;
+ }
+
+ let pivot = arr[0];
+ let mut left = vec![];
+ let mut right = vec![];
+
+ for i in 1..arr.len() {
+ if arr[i] < pivot {
+ left.push(arr[i]);
+ } else {
+ right.push(arr[i]);
+ }
+ }
+
+ quick_sort(&mut left);
+ quick_sort(&mut right);
+
+ let mut i = 0;
+ for _ in &left {
+ arr[i] = left.pop().unwrap();
+ i += 1;
+ }
+
+ for _ in &right {
+ arr[i] = right.pop().unwrap();
+ i += 1;
+ }
+}
+``
+226 tokens generated (10.91 token/s)
+```
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 6fe08963..18f18e5d 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-use candle_transformers::models::mistral::{Config, Model};
+use candle_transformers::models::mistral::{Config, Model as Mistral};
+use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,11 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
+enum Model {
+ Mistral(Mistral),
+ Quantized(QMistral),
+}
+
struct TextGeneration {
model: Model,
device: Device,
@@ -76,7 +82,10 @@ impl TextGeneration {
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
- let logits = self.model.forward(&input, start_pos)?;
+ let logits = match &mut self.model {
+ Model::Mistral(m) => m.forward(&input, start_pos)?,
+ Model::Quantized(m) => m.forward(&input, start_pos)?,
+ };
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
@@ -101,8 +110,9 @@ impl TextGeneration {
}
}
let dt = start_gen.elapsed();
- let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
- print!("{rest}");
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
@@ -211,24 +221,39 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
- None => vec![
- repo.get("pytorch_model-00001-of-00002.safetensors")?,
- repo.get("pytorch_model-00002-of-00002.safetensors")?,
- ],
+ None => {
+ if args.quantized {
+ vec![repo.get("model-q4k.gguf")?]
+ } else {
+ vec![
+ repo.get("pytorch_model-00001-of-00002.safetensors")?,
+ repo.get("pytorch_model-00002-of-00002.safetensors")?,
+ ]
+ }
+ }
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
- let device = candle_examples::device(args.cpu)?;
- let dtype = if device.is_cuda() {
- DType::BF16
+ let (model, device) = if args.quantized {
+ let filename = &filenames[0];
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let model = QMistral::new(&config, vb)?;
+ (Model::Quantized(model), Device::Cpu)
} else {
- DType::F32
+ let device = candle_examples::device(args.cpu)?;
+ let dtype = if device.is_cuda() {
+ DType::BF16
+ } else {
+ DType::F32
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+ let model = Mistral::new(&config, vb)?;
+ (Model::Mistral(model), device)
};
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
- let model = Model::new(&config, vb)?;
+
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
diff --git a/candle-examples/src/token_output_stream.rs b/candle-examples/src/token_output_stream.rs
index 3d975d63..907d8ddd 100644
--- a/candle-examples/src/token_output_stream.rs
+++ b/candle-examples/src/token_output_stream.rs
@@ -50,8 +50,20 @@ impl TokenOutputStream {
}
}
- pub fn decode_rest(&self) -> Result<String> {
- self.decode(&self.tokens[self.prev_index..])
+ pub fn decode_rest(&self) -> Result<Option<String>> {
+ let prev_text = if self.tokens.is_empty() {
+ String::new()
+ } else {
+ let tokens = &self.tokens[self.prev_index..self.current_index];
+ self.decode(tokens)?
+ };
+ let text = self.decode(&self.tokens[self.prev_index..])?;
+ if text.len() > prev_text.len() {
+ let text = text.split_at(prev_text.len());
+ Ok(Some(text.1.to_string()))
+ } else {
+ Ok(None)
+ }
}
pub fn decode_all(&self) -> Result<String> {
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index a7b4c21b..e0ecee7b 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -6,18 +6,18 @@ use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
- vocab_size: usize,
- hidden_size: usize,
- intermediate_size: usize,
- num_hidden_layers: usize,
- num_attention_heads: usize,
- num_key_value_heads: usize,
- hidden_act: Activation,
- max_position_embeddings: usize,
- rms_norm_eps: f64,
- rope_theta: f64,
- sliding_window: usize,
- use_flash_attn: bool,
+ pub(crate) vocab_size: usize,
+ pub(crate) hidden_size: usize,
+ pub(crate) intermediate_size: usize,
+ pub(crate) num_hidden_layers: usize,
+ pub(crate) num_attention_heads: usize,
+ pub(crate) num_key_value_heads: usize,
+ pub(crate) hidden_act: Activation,
+ pub(crate) max_position_embeddings: usize,
+ pub(crate) rms_norm_eps: f64,
+ pub(crate) rope_theta: f64,
+ pub(crate) sliding_window: usize,
+ pub(crate) use_flash_attn: bool,
}
impl Config {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 15d884c6..b1544579 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod llama;
pub mod mistral;
pub mod mixformer;
pub mod quantized_llama;
+pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_t5;
pub mod segment_anything;
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
new file mode 100644
index 00000000..171e7440
--- /dev/null
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -0,0 +1,364 @@
+use crate::models::quantized_t5::Embedding;
+use crate::models::with_tracing::QMatMul;
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::Activation;
+use std::sync::Arc;
+
+pub use crate::models::mistral::Config;
+
+#[derive(Debug)]
+struct Linear {
+ weight: QMatMul,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ x.apply(&self.weight)
+ }
+}
+
+fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear { weight })
+}
+
+#[derive(Debug)]
+struct RmsNorm {
+ inner: candle_nn::RmsNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let inner = candle_nn::RmsNorm::new(weight, eps);
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+#[derive(Debug)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+fn rotate_half(xs: &Tensor) -> Result<Tensor> {
+ let last_dim = xs.dim(D::Minus1)?;
+ let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
+ let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
+ Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
+}
+
+impl RotaryEmbedding {
+ fn new(cfg: &Config, dev: &Device) -> Result<Self> {
+ let dim = cfg.hidden_size / cfg.num_attention_heads;
+ let max_seq_len = cfg.max_position_embeddings;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(DType::F32)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
+ let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+ act_fn: Activation,
+}
+
+impl MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let intermediate_sz = cfg.intermediate_size;
+ let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
+ let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
+ let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+ let rhs = xs.apply(&self.up_proj)?;
+ (lhs * rhs)?.apply(&self.down_proj)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_heads: usize,
+ num_kv_heads: usize,
+ num_kv_groups: usize,
+ head_dim: usize,
+ hidden_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl Attention {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let num_kv_heads = cfg.num_key_value_heads;
+ let num_kv_groups = num_heads / num_kv_heads;
+ let head_dim = hidden_sz / num_heads;
+ let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_heads,
+ num_kv_heads,
+ num_kv_groups,
+ head_dim,
+ hidden_size: hidden_sz,
+ rotary_emb,
+ kv_cache: None,
+ })
+ }
+
+ fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
+ let n_rep = self.num_kv_groups;
+ if n_rep == 1 {
+ Ok(xs)
+ } else {
+ let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
+ xs.unsqueeze(2)?
+ .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
+ .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
+ }
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let query_states = self.q_proj.forward(xs)?;
+ let key_states = self.k_proj.forward(xs)?;
+ let value_states = self.v_proj.forward(xs)?;
+
+ let query_states = query_states
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let key_states = key_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let value_states = value_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (query_states, key_states) =
+ self.rotary_emb
+ .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+
+ let (key_states, value_states) = match &self.kv_cache {
+ None => (key_states, value_states),
+ Some((prev_k, prev_v)) => {
+ let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+ let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+ (key_states, value_states)
+ }
+ };
+ self.kv_cache = Some((key_states.clone(), value_states.clone()));
+
+ let key_states = self.repeat_kv(key_states)?;
+ let value_states = self.repeat_kv(value_states)?;
+
+ let attn_output = {
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ attn_weights.matmul(&value_states)?
+ };
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.hidden_size))?
+ .apply(&self.o_proj)
+ }
+}
+
+#[derive(Debug)]
+struct DecoderLayer {
+ self_attn: Attention,
+ mlp: MLP,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ let input_layernorm =
+ RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.input_layernorm.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
+ residual + xs
+ }
+}
+
+#[derive(Debug)]
+pub struct Model {
+ embed_tokens: Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ sliding_window: usize,
+ device: Device,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("model");
+ let embed_tokens =
+ Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(cfg, vb_m.device())?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb_m.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+ let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ sliding_window: cfg.sliding_window,
+ device: vb.device().clone(),
+ })
+ }
+
+ fn prepare_decoder_attention_mask(
+ &self,
+ b_size: usize,
+ tgt_len: usize,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ // Sliding window mask?
+ let mask: Vec<_> = (0..tgt_len)
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + self.sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+ let mask = if seqlen_offset > 0 {
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+ Tensor::cat(&[&mask0, &mask], D::Minus1)?
+ } else {
+ mask
+ };
+ mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ .to_dtype(DType::F32)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (b_size, seq_len) = input_ids.dims2()?;
+ let attention_mask = if seq_len <= 1 {
+ None
+ } else {
+ let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ Some(mask)
+ };
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+}