diff options
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | candle-core/benches/matmul.rs | 1 | ||||
-rw-r--r-- | candle-core/src/quantized/gguf_file.rs | 2 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 28 | ||||
-rw-r--r-- | candle-examples/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-examples/examples/mamba-minimal/README.md | 12 | ||||
-rw-r--r-- | candle-examples/examples/mamba-minimal/main.rs | 287 | ||||
-rw-r--r-- | candle-examples/examples/mamba-minimal/model.rs | 204 | ||||
-rw-r--r-- | candle-examples/examples/mistral/main.rs | 20 | ||||
-rw-r--r-- | candle-examples/examples/phi/main.rs | 117 | ||||
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 19 | ||||
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 34 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_mistral.rs | 14 |
13 files changed, 706 insertions, 36 deletions
@@ -65,6 +65,8 @@ We also provide a some command line based examples using state of the art models - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. +- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal + implementation of the Mamba state space model. - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28. - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of @@ -177,6 +179,7 @@ If you have an addition to this list, please submit a pull request. - Falcon. - StarCoder. - Phi 1, 1.5, and 2. + - Minimal Mamba - Mistral 7b v0.1. - Mixtral 8x7b v0.1. - StableLM-3B-4E1T. diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs index 8732f451..83679771 100644 --- a/candle-core/benches/matmul.rs +++ b/candle-core/benches/matmul.rs @@ -40,4 +40,3 @@ fn criterion_benchmark(c: &mut Criterion) { criterion_group!(benches, criterion_benchmark); criterion_main!(benches); - diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 1e9dc517..587ffc0f 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -41,7 +41,7 @@ impl VersionedMagic { (Magic::Gguf, 1) => Self::GgufV1, (Magic::Gguf, 2) => Self::GgufV2, (Magic::Gguf, 3) => Self::GgufV3, - _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), + _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"), }; Ok(versioned_magic) } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f15f8c1c..54f9fa2b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -396,7 +396,7 @@ impl Tensor { device: &Device, ) -> Result<Self> { if D::is_zero(&step) { - crate::bail!("step cannot be zero") + bail!("step cannot be zero") } let mut data = vec![]; let mut current = start; @@ -1041,6 +1041,9 @@ impl Tensor { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1; @@ -1076,6 +1079,9 @@ impl Tensor { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1; @@ -1798,7 +1804,7 @@ impl Tensor { let is_permutation = dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i)); if !is_permutation { - crate::bail!( + bail!( "dimension mismatch in permute, tensor {:?}, dims: {:?}", self.dims(), dims @@ -2293,7 +2299,7 @@ impl Tensor { if left == 0 && right == 0 { Ok(self.clone()) } else if self.elem_count() == 0 { - crate::bail!("cannot use pad_with_same on an empty tensor") + bail!("cannot use pad_with_same on an empty tensor") } else if left == 0 { let dim = dim.to_index(self.shape(), "pad_with_same")?; let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; @@ -2457,13 +2463,13 @@ impl Tensor { pub fn normalize_axis(&self, axis: i64) -> Result<usize> { let rank = self.rank() as i64; if rank <= axis { - crate::bail!("axis {axis} is too large, tensor rank {rank}") + bail!("axis {axis} is too large, tensor rank {rank}") } else if 0 <= axis { Ok(axis as usize) } else { let naxis = rank + axis; if naxis < 0 { - crate::bail!("axis {axis} is too small, tensor rank {rank}") + bail!("axis {axis} is too small, tensor rank {rank}") } Ok(naxis as usize) } @@ -2525,14 +2531,14 @@ impl Tensor { let src_dims = src.dims(); let self_dims = self.dims(); if self_dims.len() != src_dims.len() { - crate::bail!( + bail!( "slice-assign requires input with the same rank {} <> {}", self_dims.len(), src_dims.len() ) } if self_dims.len() != ranges.len() { - crate::bail!( + bail!( "slice-assign requires input with the same rank as there are ranges {} <> {}", self_dims.len(), ranges.len() @@ -2552,18 +2558,16 @@ impl Tensor { std::ops::Bound::Excluded(v) => *v, }; if end_excluded <= start_included { - crate::bail!( - "slice-assign: empty range for dim {i}, {start_included} {end_excluded}" - ) + bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") } if self_dims[i] < end_excluded { - crate::bail!( + bail!( "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", self_dims[i] ) } if end_excluded - start_included != src_dims[i] { - crate::bail!( + bail!( "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] ) } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 0c4bf20e..8ae828bd 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -28,6 +28,7 @@ safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } +csv = "1.3.0" [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-examples/examples/mamba-minimal/README.md b/candle-examples/examples/mamba-minimal/README.md new file mode 100644 index 00000000..0ce42123 --- /dev/null +++ b/candle-examples/examples/mamba-minimal/README.md @@ -0,0 +1,12 @@ +# candle-mamba-minimal: minimal implementation of Mamba + +This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal). + +## Running the example + +```bash +$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the" +Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016. + +The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer. +``` diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs new file mode 100644 index 00000000..5e8968c0 --- /dev/null +++ b/candle-examples/examples/mamba-minimal/main.rs @@ -0,0 +1,287 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; + +mod model; +use model::{Config, Model}; + +use candle::{DType, Device, Module, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the </s> token"), + }; + let start_gen = std::time::Instant::now(); + for _ in 0..sample_len { + let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mamba2_8bSlimPj, +} + +impl std::fmt::Display for Which { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn revision(&self) -> &'static str { + match self { + Self::Mamba130m + | Self::Mamba370m + | Self::Mamba790m + | Self::Mamba1_4b + | Self::Mamba2_8bSlimPj => "refs/pr/1", + Self::Mamba2_8b => "refs/pr/4", + } + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option<f64>, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 5000)] + sample_len: usize, + + #[arg(long, default_value = "mamba130m")] + which: Which, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long)] + revision: Option<String>, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + #[arg(long)] + config_file: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let repo = api.repo(Repo::with_revision( + args.model_id + .unwrap_or_else(|| args.which.model_id().to_string()), + RepoType::Model, + args.revision + .unwrap_or_else(|| args.which.revision().to_string()), + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => api + .model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json")?, + }; + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => { + vec![repo.get("model.safetensors")?] + } + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let device = candle_examples::device(args.cpu)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let model = Model::new(&config, vb.pp("backbone"))?; + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs new file mode 100644 index 00000000..4a0a345d --- /dev/null +++ b/candle-examples/examples/mamba-minimal/model.rs @@ -0,0 +1,204 @@ +/// This follows the lines of: +/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py +/// Simple, minimal implementation of Mamba in one file of PyTorch. +use candle::{IndexOp, Module, Result, Tensor, D}; +use candle_nn::{RmsNorm, VarBuilder}; + +use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear}; + +#[derive(Debug, Clone, serde::Deserialize)] +pub struct Config { + d_model: usize, + n_layer: usize, + vocab_size: usize, + pad_vocab_size_multiple: usize, +} + +impl Config { + fn vocab_size(&self) -> usize { + let pad = self.pad_vocab_size_multiple; + (self.vocab_size + pad - 1) / pad * pad + } + + fn dt_rank(&self) -> usize { + (self.d_model + 15) / 16 + } + + fn d_conv(&self) -> usize { + 4 + } + + fn d_state(&self) -> usize { + 16 + } + + fn d_inner(&self) -> usize { + self.d_model * 2 + } +} + +// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177 +#[derive(Clone, Debug)] +pub struct MambaBlock { + in_proj: Linear, + conv1d: candle_nn::Conv1d, + x_proj: Linear, + dt_proj: Linear, + a_log: Tensor, + d: Tensor, + out_proj: Linear, + dt_rank: usize, +} + +impl MambaBlock { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let d_inner = cfg.d_inner(); + let d_conv = cfg.d_conv(); + let d_state = cfg.d_state(); + let dt_rank = cfg.dt_rank(); + let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?; + let conv_cfg = candle_nn::Conv1dConfig { + groups: d_inner, + padding: d_conv - 1, + ..Default::default() + }; + let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?; + let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?; + let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?; + let a_log = vb.get((d_inner, d_state), "A_log")?; + let d = vb.get(d_inner, "D")?; + let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?; + Ok(Self { + in_proj, + conv1d, + x_proj, + dt_proj, + a_log, + d, + out_proj, + dt_rank, + }) + } + + fn ssm(&self, xs: &Tensor) -> Result<Tensor> { + let (_d_in, n) = self.a_log.dims2()?; + let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?; + let d = self.d.to_dtype(candle::DType::F32)?; + let x_dbl = xs.apply(&self.x_proj)?; + let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?; + let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?; + let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?; + let delta = delta.contiguous()?.apply(&self.dt_proj)?; + // softplus without threshold + let delta = (delta.exp()? + 1.)?.log()?; + let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?; + Ok(ss) + } +} + +// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275 +fn selective_scan( + u: &Tensor, + delta: &Tensor, + a: &Tensor, + b: &Tensor, + c: &Tensor, + d: &Tensor, +) -> Result<Tensor> { + let (b_sz, l, d_in) = u.dims3()?; + let n = a.dim(1)?; + let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1 + let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?; + let delta_b_u = delta + .broadcast_mul(&b.reshape((b_sz, 1, l, n))?)? + .broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?; + let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?; + let mut ys = Vec::with_capacity(l); + for i in 0..l { + xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?; + let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?; + ys.push(y) + } + let ys = Tensor::stack(ys.as_slice(), 1)?; + ys + u.broadcast_mul(d) +} + +impl Module for MambaBlock { + // https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206 + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b_sz, seq_len, _dim) = xs.dims3()?; + let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?; + let (xs, res) = (&xs_and_res[0], &xs_and_res[1]); + let xs = xs + .t()? + .apply(&self.conv1d)? + .narrow(D::Minus1, 0, seq_len)? + .t()?; + let xs = candle_nn::ops::silu(&xs)?; + let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?; + ys.apply(&self.out_proj) + } +} + +// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143 +#[derive(Clone, Debug)] +pub struct ResidualBlock { + mixer: MambaBlock, + norm: RmsNorm, +} + +impl ResidualBlock { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?; + let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?; + Ok(Self { mixer, norm }) + } +} + +impl Module for ResidualBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.norm)?.apply(&self.mixer)? + xs + } +} + +// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56 +#[derive(Clone, Debug)] +pub struct Model { + embedding: candle_nn::Embedding, + layers: Vec<ResidualBlock>, + norm_f: RmsNorm, + lm_head: Linear, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?; + let mut layers = Vec::with_capacity(cfg.n_layer); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.n_layer { + let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?; + let lm_head = Linear::from_weights(embedding.embeddings().clone(), None); + Ok(Self { + embedding, + layers, + norm_f, + lm_head, + }) + } +} + +impl Module for Model { + fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + let (_b_size, seq_len) = input_ids.dims2()?; + let mut xs = self.embedding.forward(input_ids)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm_f)? + .apply(&self.lm_head) + } +} diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 18f18e5d..2b31142e 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -155,8 +155,8 @@ struct Args { #[arg(long, short = 'n', default_value_t = 100)] sample_len: usize, - #[arg(long, default_value = "lmz/candle-mistral")] - model_id: String, + #[arg(long)] + model_id: Option<String>, #[arg(long, default_value = "main")] revision: String, @@ -207,8 +207,18 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let api = Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id, + None => { + if args.quantized { + "lmz/candle-mistral".to_string() + } else { + "mistralai/Mistral-7B-v0.1".to_string() + } + } + }; let repo = api.repo(Repo::with_revision( - args.model_id, + model_id, RepoType::Model, args.revision, )); @@ -226,8 +236,8 @@ fn main() -> Result<()> { vec![repo.get("model-q4k.gguf")?] } else { vec![ - repo.get("pytorch_model-00001-of-00002.safetensors")?, - repo.get("pytorch_model-00002-of-00002.safetensors")?, + repo.get("model-00001-of-00002.safetensors")?, + repo.get("model-00002-of-00002.safetensors")?, ] } } diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 52d453b5..3574b1f2 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -145,7 +145,10 @@ struct Args { verbose_prompt: bool, #[arg(long)] - prompt: String, + prompt: Option<String>, + + #[arg(long)] + mmlu_dir: Option<String>, /// The temperature used to generate samples. #[arg(long)] @@ -314,17 +317,105 @@ fn main() -> Result<()> { }; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new( - model, - tokenizer, - args.seed, - args.temperature, - args.top_p, - args.repeat_penalty, - args.repeat_last_n, - args.verbose_prompt, - &device, - ); - pipeline.run(&args.prompt, args.sample_len)?; + match (args.prompt, args.mmlu_dir) { + (None, None) | (Some(_), Some(_)) => { + anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified") + } + (Some(prompt), None) => { + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + args.verbose_prompt, + &device, + ); + pipeline.run(&prompt, args.sample_len)?; + } + (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?, + } + Ok(()) +} + +fn mmlu<P: AsRef<std::path::Path>>( + mut model: Model, + tokenizer: Tokenizer, + device: &Device, + mmlu_dir: P, +) -> anyhow::Result<()> { + for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() { + let dir_entry = dir_entry.path(); + let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) { + None => "".to_string(), + Some(v) => match v.strip_suffix("_test") { + None => v.replace('_', " "), + Some(v) => v.replace('_', " "), + }, + }; + if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") { + continue; + } + println!("reading {dir_entry:?}"); + let dir_entry = std::fs::File::open(dir_entry)?; + let mut reader = csv::ReaderBuilder::new() + .has_headers(false) + .from_reader(dir_entry); + let token_a = tokenizer.token_to_id("A").unwrap(); + let token_b = tokenizer.token_to_id("B").unwrap(); + let token_c = tokenizer.token_to_id("C").unwrap(); + let token_d = tokenizer.token_to_id("D").unwrap(); + for row in reader.records() { + let row = match row { + Err(_) => continue, + Ok(row) => row, + }; + if row.len() < 5 { + continue; + } + let question = row.get(0).unwrap(); + let answer_a = row.get(1).unwrap(); + let answer_b = row.get(2).unwrap(); + let answer_c = row.get(3).unwrap(); + let answer_d = row.get(4).unwrap(); + let answer = row.get(5).unwrap(); + let prompt = format!( + "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n", + "The following are multiple choice questions (with answers) about" + ); + let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?; + let tokens = tokens.get_ids().to_vec(); + let input = Tensor::new(tokens, device)?.unsqueeze(0)?; + let logits = match &mut model { + Model::MixFormer(m) => { + m.clear_kv_cache(); + m.forward(&input)? + } + Model::Quantized(m) => { + m.clear_kv_cache(); + m.forward(&input)? + } + }; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits_v: Vec<f32> = logits.to_vec1()?; + let pr_a = logits_v[token_a as usize]; + let pr_b = logits_v[token_b as usize]; + let pr_c = logits_v[token_c as usize]; + let pr_d = logits_v[token_d as usize]; + let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d { + "A" + } else if pr_b > pr_c && pr_b > pr_d { + "B" + } else if pr_c > pr_d { + "C" + } else { + "D" + }; + + println!("{prompt}\n -> {model_answer} vs {answer}"); + } + } Ok(()) } diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index df758b4f..bfc6de53 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -53,6 +53,8 @@ enum Which { Mistral7b, #[value(name = "7b-mistral-instruct")] Mistral7bInstruct, + #[value(name = "7b-mistral-instruct-v0.2")] + Mistral7bInstructV02, #[value(name = "7b-zephyr-a")] Zephyr7bAlpha, #[value(name = "7b-zephyr-b")] @@ -90,7 +92,8 @@ impl Which { | Self::Mixtral | Self::MixtralInstruct | Self::Mistral7b - | Self::Mistral7bInstruct => true, + | Self::Mistral7bInstruct + | Self::Mistral7bInstructV02 => true, } } @@ -111,6 +114,7 @@ impl Which { | Self::MixtralInstruct | Self::Mistral7b | Self::Mistral7bInstruct + | Self::Mistral7bInstructV02 | Self::OpenChat35 | Self::Starling7bAlpha => false, Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true, @@ -134,6 +138,7 @@ impl Which { | Self::MixtralInstruct | Self::Mistral7b | Self::Mistral7bInstruct + | Self::Mistral7bInstructV02 | Self::Zephyr7bAlpha | Self::Zephyr7bBeta => false, Self::OpenChat35 | Self::Starling7bAlpha => true, @@ -157,6 +162,7 @@ impl Which { Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1", Which::Mistral7b | Which::Mistral7bInstruct + | Which::Mistral7bInstructV02 | Which::Zephyr7bAlpha | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1", Which::OpenChat35 => "openchat/openchat_3.5", @@ -168,7 +174,7 @@ impl Which { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { - /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp + /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp #[arg(long)] model: Option<String>, @@ -284,6 +290,10 @@ impl Args { "TheBloke/Mistral-7B-Instruct-v0.1-GGUF", "mistral-7b-instruct-v0.1.Q4_K_S.gguf", ), + Which::Mistral7bInstructV02 => ( + "TheBloke/Mistral-7B-Instruct-v0.2-GGUF", + "mistral-7b-instruct-v0.2.Q4_K_S.gguf", + ), Which::Zephyr7bAlpha => ( "TheBloke/zephyr-7B-alpha-GGUF", "zephyr-7b-alpha.Q4_K_M.gguf", @@ -354,7 +364,7 @@ fn main() -> anyhow::Result<()> { let mut model = match model_path.extension().and_then(|v| v.to_str()) { Some("gguf") => { - let model = gguf_file::Content::read(&mut file)?; + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensor_infos.iter() { let elem_count = tensor.shape.elem_count(); @@ -370,7 +380,7 @@ fn main() -> anyhow::Result<()> { ModelWeights::from_gguf(model, &mut file)? } Some("ggml" | "bin") | Some(_) | None => { - let model = ggml_file::Content::read(&mut file)?; + let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; let mut total_size_in_bytes = 0; for (_, tensor) in model.tensors.iter() { let elem_count = tensor.shape().elem_count(); @@ -398,6 +408,7 @@ fn main() -> anyhow::Result<()> { | Which::MixtralInstruct | Which::Mistral7b | Which::Mistral7bInstruct + | Which::Mistral7bInstructV02 | Which::Zephyr7bAlpha | Which::Zephyr7bBeta | Which::L70b diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index caf96bce..2a66515b 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -21,6 +21,7 @@ pub struct Config { } impl Config { + // https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json pub fn config_7b_v0_1(use_flash_attn: bool) -> Self { Self { vocab_size: 32000, @@ -37,6 +38,25 @@ impl Config { use_flash_attn, } } + + // https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json + // https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json + pub fn config_chat_ml(use_flash_attn: bool) -> Self { + Self { + vocab_size: 32002, + hidden_size: 4096, + intermediate_size: 14336, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 8, + hidden_act: Activation::Silu, + max_position_embeddings: 32768, + rms_norm_eps: 1e-5, + rope_theta: 10_000., + sliding_window: 4096, + use_flash_attn, + } + } } #[derive(Debug, Clone)] @@ -277,6 +297,10 @@ impl Attention { .reshape((b_sz, q_len, self.hidden_size))? .apply(&self.o_proj) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } } #[derive(Debug, Clone)] @@ -320,6 +344,10 @@ impl DecoderLayer { let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; residual + xs } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } } #[derive(Debug, Clone)] @@ -403,4 +431,10 @@ impl Model { .apply(&self.norm)? .apply(&self.lm_head) } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } } diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 9e306c67..f2cb3b27 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -198,6 +198,10 @@ impl Attention { .reshape((b_sz, q_len, self.hidden_size))? .apply(&self.o_proj) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } } #[derive(Debug, Clone)] @@ -241,6 +245,10 @@ impl DecoderLayer { let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; residual + xs } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } } #[derive(Debug, Clone)] @@ -322,4 +330,10 @@ impl Model { .apply(&self.norm)? .apply(&self.lm_head) } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } } |