 README.md                                           |   3
 candle-core/benches/matmul.rs                       |   1
 candle-core/src/quantized/gguf_file.rs              |   2
 candle-core/src/tensor.rs                           |  28
 candle-examples/Cargo.toml                          |   1
 candle-examples/examples/mamba-minimal/README.md    |  12
 candle-examples/examples/mamba-minimal/main.rs      | 287
 candle-examples/examples/mamba-minimal/model.rs     | 204
 candle-examples/examples/mistral/main.rs            |  20
 candle-examples/examples/phi/main.rs                | 117
 candle-examples/examples/quantized/main.rs          |  19
 candle-transformers/src/models/mistral.rs           |  34
 candle-transformers/src/models/quantized_mistral.rs |  14
 13 files changed, 706 insertions(+), 36 deletions(-)
diff --git a/README.md b/README.md
index 26a81642..9f6cf9da 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,8 @@ We also provide some command line based examples using state of the art models
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
+- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
+ implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
@@ -177,6 +179,7 @@ If you have an addition to this list, please submit a pull request.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
+ - Minimal Mamba.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T.
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs
index 8732f451..83679771 100644
--- a/candle-core/benches/matmul.rs
+++ b/candle-core/benches/matmul.rs
@@ -40,4 +40,3 @@ fn criterion_benchmark(c: &mut Criterion) {
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
-
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 1e9dc517..587ffc0f 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -41,7 +41,7 @@ impl VersionedMagic {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
- _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
+ _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f15f8c1c..54f9fa2b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -396,7 +396,7 @@ impl Tensor {
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
- crate::bail!("step cannot be zero")
+ bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
@@ -1041,6 +1041,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
+ if h < kernel_size.0 || w < kernel_size.1 {
+ bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+ }
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@@ -1076,6 +1079,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
+ if h < kernel_size.0 || w < kernel_size.1 {
+ bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+ }
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@@ -1798,7 +1804,7 @@ impl Tensor {
let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation {
- crate::bail!(
+ bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
dims
@@ -2293,7 +2299,7 @@ impl Tensor {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
- crate::bail!("cannot use pad_with_same on an empty tensor")
+ bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
@@ -2457,13 +2463,13 @@ impl Tensor {
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
- crate::bail!("axis {axis} is too large, tensor rank {rank}")
+ bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
- crate::bail!("axis {axis} is too small, tensor rank {rank}")
+ bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
@@ -2525,14 +2531,14 @@ impl Tensor {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
- crate::bail!(
+ bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
- crate::bail!(
+ bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
@@ -2552,18 +2558,16 @@ impl Tensor {
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
- crate::bail!(
- "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
- )
+ bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
}
if self_dims[i] < end_excluded {
- crate::bail!(
+ bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
- crate::bail!(
+ bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 0c4bf20e..8ae828bd 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -28,6 +28,7 @@ safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
+csv = "1.3.0"
[dev-dependencies]
anyhow = { workspace = true }
diff --git a/candle-examples/examples/mamba-minimal/README.md b/candle-examples/examples/mamba-minimal/README.md
new file mode 100644
index 00000000..0ce42123
--- /dev/null
+++ b/candle-examples/examples/mamba-minimal/README.md
@@ -0,0 +1,12 @@
+# candle-mamba-minimal: minimal implementation of Mamba
+
+This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
+
+## Running the example
+
+```bash
+$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
+Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.
+
+The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
+```
diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs
new file mode 100644
index 00000000..5e8968c0
--- /dev/null
+++ b/candle-examples/examples/mamba-minimal/main.rs
@@ -0,0 +1,287 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::{Parser, ValueEnum};
+
+mod model;
+use model::{Config, Model};
+
+use candle::{DType, Device, Module, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+ model: Model,
+ device: Device,
+ tokenizer: TokenOutputStream,
+ logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
+impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ tokenizer: Tokenizer,
+ seed: u64,
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ device: &Device,
+ ) -> Self {
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ Self {
+ model,
+ tokenizer: TokenOutputStream::new(tokenizer),
+ logits_processor,
+ repeat_penalty,
+ repeat_last_n,
+ device: device.clone(),
+ }
+ }
+
+ fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+ use std::io::Write;
+ self.tokenizer.clear();
+ let mut tokens = self
+ .tokenizer
+ .tokenizer()
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ for &t in tokens.iter() {
+ if let Some(t) = self.tokenizer.next_token(t)? {
+ print!("{t}")
+ }
+ }
+ std::io::stdout().flush()?;
+
+ let mut generated_tokens = 0usize;
+ let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+ Some(token) => token,
+            None => anyhow::bail!("cannot find the <|endoftext|> token"),
+ };
+ let start_gen = std::time::Instant::now();
+ for _ in 0..sample_len {
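+            // The minimal model keeps no recurrent state between calls, so the full
+            // token sequence (prompt plus generated tokens) is re-processed each step.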
+ let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+ let logits = self.model.forward(&input)?;
+ let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
+
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == eos_token {
+ break;
+ }
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ }
+ let dt = start_gen.elapsed();
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+ }
+}
+
+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+ Mamba130m,
+ Mamba370m,
+ Mamba790m,
+ Mamba1_4b,
+ Mamba2_8b,
+ Mamba2_8bSlimPj,
+}
+
+impl std::fmt::Display for Which {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl Which {
+ fn model_id(&self) -> &'static str {
+ match self {
+ Self::Mamba130m => "state-spaces/mamba-130m",
+ Self::Mamba370m => "state-spaces/mamba-370m",
+ Self::Mamba790m => "state-spaces/mamba-790m",
+ Self::Mamba1_4b => "state-spaces/mamba-1.4b",
+ Self::Mamba2_8b => "state-spaces/mamba-2.8b",
+            Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
+ }
+ }
+
+ fn revision(&self) -> &'static str {
+ match self {
+ Self::Mamba130m
+ | Self::Mamba370m
+ | Self::Mamba790m
+ | Self::Mamba1_4b
+ | Self::Mamba2_8bSlimPj => "refs/pr/1",
+ Self::Mamba2_8b => "refs/pr/4",
+ }
+ }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long)]
+ temperature: Option<f64>,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, short = 'n', default_value_t = 5000)]
+ sample_len: usize,
+
+ #[arg(long, default_value = "mamba130m")]
+ which: Which,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature.unwrap_or(0.),
+ args.repeat_penalty,
+ args.repeat_last_n
+ );
+
+ let start = std::time::Instant::now();
+ let api = Api::new()?;
+ let repo = api.repo(Repo::with_revision(
+ args.model_id
+ .unwrap_or_else(|| args.which.model_id().to_string()),
+ RepoType::Model,
+ args.revision
+ .unwrap_or_else(|| args.which.revision().to_string()),
+ ));
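+    // The mamba checkpoints do not ship a tokenizer; fall back to the GPT-NeoX
+    // tokenizer that these models were trained with.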
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => api
+ .model("EleutherAI/gpt-neox-20b".to_string())
+ .get("tokenizer.json")?,
+ };
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+ let filenames = match args.weight_files {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => {
+ vec![repo.get("model.safetensors")?]
+ }
+ };
+ println!("retrieved the files in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let start = std::time::Instant::now();
+ let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+ let device = candle_examples::device(args.cpu)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
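+    // The state-spaces checkpoints store the model weights under the "backbone" prefix.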
+ let model = Model::new(&config, vb.pp("backbone"))?;
+ println!("loaded the model in {:?}", start.elapsed());
+
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ &device,
+ );
+ pipeline.run(&args.prompt, args.sample_len)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs
new file mode 100644
index 00000000..4a0a345d
--- /dev/null
+++ b/candle-examples/examples/mamba-minimal/model.rs
@@ -0,0 +1,204 @@
+/// This follows the lines of:
+/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
+/// Simple, minimal implementation of Mamba in one file of PyTorch.
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::{RmsNorm, VarBuilder};
+
+use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ d_model: usize,
+ n_layer: usize,
+ vocab_size: usize,
+ pad_vocab_size_multiple: usize,
+}
+
+impl Config {
+ fn vocab_size(&self) -> usize {
+ let pad = self.pad_vocab_size_multiple;
+ (self.vocab_size + pad - 1) / pad * pad
+ }
+
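+    // The remaining hyper-parameters are fixed to the mamba-minimal defaults:
+    // dt_rank = ceil(d_model / 16), d_conv = 4, d_state = 16, expand = 2 (so d_inner = 2 * d_model).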
+ fn dt_rank(&self) -> usize {
+ (self.d_model + 15) / 16
+ }
+
+ fn d_conv(&self) -> usize {
+ 4
+ }
+
+ fn d_state(&self) -> usize {
+ 16
+ }
+
+ fn d_inner(&self) -> usize {
+ self.d_model * 2
+ }
+}
+
+// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
+#[derive(Clone, Debug)]
+pub struct MambaBlock {
+ in_proj: Linear,
+ conv1d: candle_nn::Conv1d,
+ x_proj: Linear,
+ dt_proj: Linear,
+ a_log: Tensor,
+ d: Tensor,
+ out_proj: Linear,
+ dt_rank: usize,
+}
+
+impl MambaBlock {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let d_inner = cfg.d_inner();
+ let d_conv = cfg.d_conv();
+ let d_state = cfg.d_state();
+ let dt_rank = cfg.dt_rank();
+ let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
+ let conv_cfg = candle_nn::Conv1dConfig {
+ groups: d_inner,
+ padding: d_conv - 1,
+ ..Default::default()
+ };
+ let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
+ let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
+ let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
+ let a_log = vb.get((d_inner, d_state), "A_log")?;
+ let d = vb.get(d_inner, "D")?;
+ let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
+ Ok(Self {
+ in_proj,
+ conv1d,
+ x_proj,
+ dt_proj,
+ a_log,
+ d,
+ out_proj,
+ dt_rank,
+ })
+ }
+
+ fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_d_in, n) = self.a_log.dims2()?;
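+        // A is stored as a log; recover A = -exp(A_log) so the recurrence stays stable.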
+ let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
+ let d = self.d.to_dtype(candle::DType::F32)?;
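+        // x_proj packs delta, B and C into a single projection; split it along the last dim.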
+ let x_dbl = xs.apply(&self.x_proj)?;
+ let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
+ let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
+ let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
+ let delta = delta.contiguous()?.apply(&self.dt_proj)?;
+        // softplus, i.e. log(1 + exp(x)), without the numerical-stability threshold of the reference implementation
+ let delta = (delta.exp()? + 1.)?.log()?;
+ let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
+ Ok(ss)
+ }
+}
+
+// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
+fn selective_scan(
+ u: &Tensor,
+ delta: &Tensor,
+ a: &Tensor,
+ b: &Tensor,
+ c: &Tensor,
+ d: &Tensor,
+) -> Result<Tensor> {
+ let (b_sz, l, d_in) = u.dims3()?;
+ let n = a.dim(1)?;
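+    // Discretize A with a zero-order hold (exp(delta * A)) and B with a simplified Euler
+    // step (delta * B * u), then run the recurrence x_t = dA_t * x_{t-1} + dBu_t and read
+    // out y_t = x_t . C_t, one time step at a time.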
+ let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
+ let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
+ let delta_b_u = delta
+ .broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
+ .broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
+ let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
+ let mut ys = Vec::with_capacity(l);
+ for i in 0..l {
+ xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
+ let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
+ ys.push(y)
+ }
+ let ys = Tensor::stack(ys.as_slice(), 1)?;
+ ys + u.broadcast_mul(d)
+}
+
+impl Module for MambaBlock {
+ // https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b_sz, seq_len, _dim) = xs.dims3()?;
+ let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
+ let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
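+        // Depthwise causal conv1d: groups = d_inner with a left padding of d_conv - 1,
+        // truncated back to seq_len after the convolution.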
+ let xs = xs
+ .t()?
+ .apply(&self.conv1d)?
+ .narrow(D::Minus1, 0, seq_len)?
+ .t()?;
+ let xs = candle_nn::ops::silu(&xs)?;
+ let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
+ ys.apply(&self.out_proj)
+ }
+}
+
+// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
+#[derive(Clone, Debug)]
+pub struct ResidualBlock {
+ mixer: MambaBlock,
+ norm: RmsNorm,
+}
+
+impl ResidualBlock {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
+ let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
+ Ok(Self { mixer, norm })
+ }
+}
+
+impl Module for ResidualBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
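+        // Pre-norm residual connection: mixer(rms_norm(xs)) + xs.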
+ xs.apply(&self.norm)?.apply(&self.mixer)? + xs
+ }
+}
+
+// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
+#[derive(Clone, Debug)]
+pub struct Model {
+ embedding: candle_nn::Embedding,
+ layers: Vec<ResidualBlock>,
+ norm_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
+ let mut layers = Vec::with_capacity(cfg.n_layer);
+ let vb_l = vb.pp("layers");
+ for layer_idx in 0..cfg.n_layer {
+ let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
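+        // The LM head is tied to the token embedding weights.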
+ let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
+ Ok(Self {
+ embedding,
+ layers,
+ norm_f,
+ lm_head,
+ })
+ }
+}
+
+impl Module for Model {
+ fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+ let (_b_size, seq_len) = input_ids.dims2()?;
+ let mut xs = self.embedding.forward(input_ids)?;
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs)?
+ }
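+        // Only the last position is needed to sample the next token.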
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm_f)?
+ .apply(&self.lm_head)
+ }
+}
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 18f18e5d..2b31142e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -155,8 +155,8 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
- #[arg(long, default_value = "lmz/candle-mistral")]
- model_id: String,
+ #[arg(long)]
+ model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
@@ -207,8 +207,18 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => {
+ if args.quantized {
+ "lmz/candle-mistral".to_string()
+ } else {
+ "mistralai/Mistral-7B-v0.1".to_string()
+ }
+ }
+ };
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ model_id,
RepoType::Model,
args.revision,
));
@@ -226,8 +236,8 @@ fn main() -> Result<()> {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![
- repo.get("pytorch_model-00001-of-00002.safetensors")?,
- repo.get("pytorch_model-00002-of-00002.safetensors")?,
+ repo.get("model-00001-of-00002.safetensors")?,
+ repo.get("model-00002-of-00002.safetensors")?,
]
}
}
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 52d453b5..3574b1f2 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -145,7 +145,10 @@ struct Args {
verbose_prompt: bool,
#[arg(long)]
- prompt: String,
+ prompt: Option<String>,
+
+ #[arg(long)]
+ mmlu_dir: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
@@ -314,17 +317,105 @@ fn main() -> Result<()> {
};
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- args.top_p,
- args.repeat_penalty,
- args.repeat_last_n,
- args.verbose_prompt,
- &device,
- );
- pipeline.run(&args.prompt, args.sample_len)?;
+ match (args.prompt, args.mmlu_dir) {
+ (None, None) | (Some(_), Some(_)) => {
+ anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+ }
+ (Some(prompt), None) => {
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ args.verbose_prompt,
+ &device,
+ );
+ pipeline.run(&prompt, args.sample_len)?;
+ }
+ (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+ }
+ Ok(())
+}
+
+fn mmlu<P: AsRef<std::path::Path>>(
+ mut model: Model,
+ tokenizer: Tokenizer,
+ device: &Device,
+ mmlu_dir: P,
+) -> anyhow::Result<()> {
+ for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+ let dir_entry = dir_entry.path();
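+        // Derive the subject name from the file stem, e.g. "abstract_algebra_test.csv" -> "abstract algebra".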
+ let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+ None => "".to_string(),
+ Some(v) => match v.strip_suffix("_test") {
+ None => v.replace('_', " "),
+ Some(v) => v.replace('_', " "),
+ },
+ };
+ if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+ continue;
+ }
+ println!("reading {dir_entry:?}");
+ let dir_entry = std::fs::File::open(dir_entry)?;
+ let mut reader = csv::ReaderBuilder::new()
+ .has_headers(false)
+ .from_reader(dir_entry);
+ let token_a = tokenizer.token_to_id("A").unwrap();
+ let token_b = tokenizer.token_to_id("B").unwrap();
+ let token_c = tokenizer.token_to_id("C").unwrap();
+ let token_d = tokenizer.token_to_id("D").unwrap();
+ for row in reader.records() {
+ let row = match row {
+ Err(_) => continue,
+ Ok(row) => row,
+ };
+ if row.len() < 5 {
+ continue;
+ }
+ let question = row.get(0).unwrap();
+ let answer_a = row.get(1).unwrap();
+ let answer_b = row.get(2).unwrap();
+ let answer_c = row.get(3).unwrap();
+ let answer_d = row.get(4).unwrap();
+ let answer = row.get(5).unwrap();
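+            // Zero-shot MMLU prompt: subject header, the question, then the four lettered choices.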
+ let prompt = format!(
+ "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+ "The following are multiple choice questions (with answers) about"
+ );
+ let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+ let tokens = tokens.get_ids().to_vec();
+ let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+ let logits = match &mut model {
+ Model::MixFormer(m) => {
+ m.clear_kv_cache();
+ m.forward(&input)?
+ }
+ Model::Quantized(m) => {
+ m.clear_kv_cache();
+ m.forward(&input)?
+ }
+ };
+ let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ let pr_a = logits_v[token_a as usize];
+ let pr_b = logits_v[token_b as usize];
+ let pr_c = logits_v[token_c as usize];
+ let pr_d = logits_v[token_d as usize];
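+            // The predicted answer is whichever of the A/B/C/D tokens has the highest logit.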
+ let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+ "A"
+ } else if pr_b > pr_c && pr_b > pr_d {
+ "B"
+ } else if pr_c > pr_d {
+ "C"
+ } else {
+ "D"
+ };
+
+ println!("{prompt}\n -> {model_answer} vs {answer}");
+ }
+ }
Ok(())
}
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index df758b4f..bfc6de53 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -53,6 +53,8 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
+ #[value(name = "7b-mistral-instruct-v0.2")]
+ Mistral7bInstructV02,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
@@ -90,7 +92,8 @@ impl Which {
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
- | Self::Mistral7bInstruct => true,
+ | Self::Mistral7bInstruct
+ | Self::Mistral7bInstructV02 => true,
}
}
@@ -111,6 +114,7 @@ impl Which {
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
+ | Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
@@ -134,6 +138,7 @@ impl Which {
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
+ | Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
@@ -157,6 +162,7 @@ impl Which {
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
+ | Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
@@ -168,7 +174,7 @@ impl Which {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
- /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
+ /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
@@ -284,6 +290,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
+ Which::Mistral7bInstructV02 => (
+ "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+ "mistral-7b-instruct-v0.2.Q4_K_S.gguf",
+ ),
Which::Zephyr7bAlpha => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
@@ -354,7 +364,7 @@ fn main() -> anyhow::Result<()> {
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
- let model = gguf_file::Content::read(&mut file)?;
+ let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
@@ -370,7 +380,7 @@ fn main() -> anyhow::Result<()> {
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
- let model = ggml_file::Content::read(&mut file)?;
+ let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@@ -398,6 +408,7 @@ fn main() -> anyhow::Result<()> {
| Which::MixtralInstruct
| Which::Mistral7b
| Which::Mistral7bInstruct
+ | Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index caf96bce..2a66515b 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -21,6 +21,7 @@ pub struct Config {
}
impl Config {
+ // https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
Self {
vocab_size: 32000,
@@ -37,6 +38,25 @@ impl Config {
use_flash_attn,
}
}
+
+ // https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json
+ // https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
+ pub fn config_chat_ml(use_flash_attn: bool) -> Self {
+ Self {
+ vocab_size: 32002,
+ hidden_size: 4096,
+ intermediate_size: 14336,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 8,
+ hidden_act: Activation::Silu,
+ max_position_embeddings: 32768,
+ rms_norm_eps: 1e-5,
+ rope_theta: 10_000.,
+ sliding_window: 4096,
+ use_flash_attn,
+ }
+ }
}
#[derive(Debug, Clone)]
@@ -277,6 +297,10 @@ impl Attention {
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
}
#[derive(Debug, Clone)]
@@ -320,6 +344,10 @@ impl DecoderLayer {
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache()
+ }
}
#[derive(Debug, Clone)]
@@ -403,4 +431,10 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
+
+ pub fn clear_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.clear_kv_cache()
+ }
+ }
}
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 9e306c67..f2cb3b27 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -198,6 +198,10 @@ impl Attention {
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
}
#[derive(Debug, Clone)]
@@ -241,6 +245,10 @@ impl DecoderLayer {
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache()
+ }
}
#[derive(Debug, Clone)]
@@ -322,4 +330,10 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
+
+ pub fn clear_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.clear_kv_cache()
+ }
+ }
}