summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/replit-code/main.rs41
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/mpt.rs8
-rw-r--r--candle-transformers/src/models/quantized_mpt.rs201
-rw-r--r--candle-transformers/src/quantized_nn.rs5
5 files changed, 247 insertions, 9 deletions
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 82c6c980..0f72b862 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-use candle_transformers::models::mpt::{Config, Model};
+use candle_transformers::models::mpt::{Config, Model as M};
+use candle_transformers::models::quantized_mpt::Model as Q;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
+enum Model {
+ M(M),
+ Q(Q),
+}
+
+impl Model {
+ fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
+ match self {
+ Self::M(model) => model.forward(xs),
+ Self::Q(model) => model.forward(xs),
+ }
+ }
+}
+
struct TextGeneration {
model: Model,
device: Device,
@@ -149,6 +164,9 @@ struct Args {
revision: Option<String>,
#[arg(long)]
+ quantized: bool,
+
+ #[arg(long)]
weight_file: Option<String>,
#[arg(long)]
@@ -206,16 +224,29 @@ fn main() -> Result<()> {
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
- None => repo.get("model.safetensors")?,
+ None => {
+ if args.quantized {
+ repo.get("model-replit-code-v1_5-q4k.gguf")?
+ } else {
+ repo.get("model.safetensors")?
+ }
+ }
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
- let device = candle_examples::device(args.cpu)?;
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = Model::new(&config, vb.pp("transformer"))?;
+ let (model, device) = if args.quantized {
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+ let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
+ (model, Device::Cpu)
+ } else {
+ let device = candle_examples::device(args.cpu)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+ let model = Model::M(M::new(&config, vb.pp("transformer"))?);
+ (model, device)
+ };
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 8a02e2da..fc57e732 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -11,6 +11,7 @@ pub mod mpt;
pub mod quantized_llama;
pub mod quantized_mistral;
pub mod quantized_mixformer;
+pub mod quantized_mpt;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod segment_anything;
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 300a1d57..0d91bf94 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -137,7 +137,7 @@ impl GroupedQueryAttention {
// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
// (batch, num_attention_heads, seqlen, head_dim)
-fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(xs)
} else {
@@ -206,7 +206,7 @@ impl MPTBlock {
}
}
-fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
+pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
let full = !cfg.is_causal();
let seq_len = cfg.max_seq_len;
let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
@@ -289,14 +289,14 @@ impl Model {
}
}
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device)
}
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
new file mode 100644
index 00000000..7586e4c0
--- /dev/null
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -0,0 +1,201 @@
+use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+/// MPT model used by replit-code-v1_5-3b
+/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::LayerNorm;
+
+pub use super::mpt::Config;
+
+#[derive(Debug)]
+struct GroupedQueryAttention {
+ wqkv: Linear,
+ out_proj: Linear,
+ kv_cache: Option<(Tensor, Tensor)>,
+ softmax_scale: f64,
+ head_dim: usize,
+ d_model: usize,
+ n_heads: usize,
+ kv_n_heads: usize,
+ attn_bias: Tensor,
+ span: tracing::Span,
+}
+
+impl GroupedQueryAttention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let head_dim = cfg.d_model / cfg.n_heads;
+ let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+ let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
+ let softmax_scale = 1f64 / (head_dim as f64).sqrt();
+ let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+ let attn_bias = super::mpt::build_alibi_bias(cfg)?.to_device(vb.device())?;
+ Ok(Self {
+ wqkv,
+ out_proj,
+ kv_cache: None,
+ softmax_scale,
+ head_dim,
+ d_model: cfg.d_model,
+ n_heads: cfg.n_heads,
+ kv_n_heads: cfg.kv_n_heads,
+ attn_bias,
+ span: tracing::span!(tracing::Level::TRACE, "gqa"),
+ })
+ }
+
+ fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b_size, seq_len, _n_embd) = xs.dims3()?;
+ let qkv = self.wqkv.forward(xs)?;
+ let query = qkv.narrow(2, 0, self.d_model)?;
+ let kv_size = self.kv_n_heads * self.head_dim;
+ let key = qkv.narrow(2, self.d_model, kv_size)?;
+ let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
+ // scaled_multihead_dot_product_attention
+ let query = query
+ .reshape((b_size, seq_len, self.n_heads, ()))?
+ .transpose(1, 2)?; // b,h,s,d
+ let key = key
+ .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+ .permute((0, 2, 3, 1))?; // b,h,d,s
+ let value = value
+ .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+ .transpose(1, 2)?; // b,h,s,d
+ let (key, value) = match &self.kv_cache {
+ None => (key, value),
+ Some((prev_k, prev_v)) => {
+ let k = Tensor::cat(&[prev_k, &key], 3)?;
+ let v = Tensor::cat(&[prev_v, &value], 2)?;
+ (k, v)
+ }
+ };
+ self.kv_cache = Some((key.clone(), value.clone()));
+ let query = query.contiguous()?;
+ let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
+ let attn_bias = {
+ let s_q = query.dim(D::Minus2)?;
+ let s_k = key.dim(D::Minus1)?;
+ let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
+ let start_q = a_q.saturating_sub(s_q);
+ let start_k = a_k.saturating_sub(s_k);
+ self.attn_bias.i((.., .., start_q.., start_k..))?
+ };
+ let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
+ let attn_weights = match mask {
+ None => attn_weights,
+ Some(mask) => super::mpt::masked_fill(
+ &attn_weights,
+ &mask.broadcast_as(attn_weights.shape())?,
+ f32::NEG_INFINITY,
+ )?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ let attn_output = attn_weights
+ .matmul(&value)?
+ .transpose(1, 2)?
+ .flatten_from(D::Minus2)?;
+ let out = attn_output.apply(&self.out_proj)?;
+ Ok(out)
+ }
+}
+
+#[derive(Debug)]
+struct Ffn {
+ up_proj: Linear,
+ down_proj: Linear,
+}
+
+impl Ffn {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden = cfg.d_model * cfg.expansion_ratio;
+ let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+ let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
+ Ok(Self { up_proj, down_proj })
+ }
+}
+
+impl Module for Ffn {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
+ }
+}
+
+#[derive(Debug)]
+struct MPTBlock {
+ norm1: LayerNorm, // Do we need the low-precision variant?
+ attn: GroupedQueryAttention,
+ norm2: LayerNorm,
+ ffn: Ffn,
+}
+
+impl MPTBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let norm1 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
+ let norm2 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+ let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
+ let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
+ Ok(Self {
+ norm1,
+ attn,
+ norm2,
+ ffn,
+ })
+ }
+
+ fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+ let residual = xs;
+ let xs = xs.apply(&self.norm1)?;
+ let xs = self.attn.forward(&xs, mask)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+pub struct Model {
+ wte: Embedding,
+ blocks: Vec<MPTBlock>,
+ norm_f: LayerNorm,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
+ let vb_b = vb.pp("blocks");
+ let mut blocks = Vec::with_capacity(cfg.n_layers);
+ for i in 0..cfg.n_layers {
+ let block = MPTBlock::new(cfg, vb_b.pp(i))?;
+ blocks.push(block)
+ }
+ let norm_f = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+ Ok(Self {
+ wte,
+ blocks,
+ norm_f,
+ })
+ }
+
+ pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let (_b_size, seq_len) = xs.dims2()?;
+ let mut xs = xs.apply(&self.wte)?;
+ let mask = if seq_len <= 1 {
+ None
+ } else {
+ Some(super::mpt::get_mask(seq_len, xs.device())?)
+ };
+ for block in self.blocks.iter_mut() {
+ xs = block.forward(&xs, mask.as_ref())?;
+ }
+ let xs = xs.apply(&self.norm_f)?;
+ let logits = xs
+ .narrow(1, seq_len - 1, 1)?
+ .squeeze(1)?
+ .matmul(&self.wte.embeddings().t()?)?
+ .squeeze(1)?;
+ Ok(logits)
+ }
+}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 1745327d..d71c3b60 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -59,6 +59,11 @@ pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::La
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}
+pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
+}
+
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear { weight, bias: None })