summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md4
-rw-r--r--candle-examples/examples/mamba-minimal/README.md3
-rw-r--r--candle-examples/examples/mamba/README.md17
-rw-r--r--candle-examples/examples/mamba/main.rs299
-rw-r--r--candle-transformers/src/models/mamba.rs211
-rw-r--r--candle-transformers/src/models/mod.rs1
6 files changed, 533 insertions, 2 deletions
diff --git a/README.md b/README.md
index 90344b34..9bfa30d8 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ We also provide a some command line based examples using state of the art models
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
-- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
+- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
@@ -186,7 +186,7 @@ If you have an addition to this list, please submit a pull request.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- - Minimal Mamba
+ - Mamba, Minimal Mamba
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
diff --git a/candle-examples/examples/mamba-minimal/README.md b/candle-examples/examples/mamba-minimal/README.md
index 0ce42123..46479828 100644
--- a/candle-examples/examples/mamba-minimal/README.md
+++ b/candle-examples/examples/mamba-minimal/README.md
@@ -2,6 +2,9 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
+Compared to the mamba example, this version can handle training but is much
+slower.
+
## Running the example
```bash
diff --git a/candle-examples/examples/mamba/README.md b/candle-examples/examples/mamba/README.md
new file mode 100644
index 00000000..507434a1
--- /dev/null
+++ b/candle-examples/examples/mamba/README.md
@@ -0,0 +1,17 @@
+# candle-mamba: Mamba implementation
+
+Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
+the transformer architecture. It leverages State Space Models (SSMs) with the
+goal of being computationally efficient on long sequences. The implementation is
+based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
+
+- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
+
+Compared to the mamba-minimal example, this version is far more efficient but
+would only work for inference.
+## Running the example
+
+```bash
+$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
+```
+
diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs
new file mode 100644
index 00000000..4802f960
--- /dev/null
+++ b/candle-examples/examples/mamba/main.rs
@@ -0,0 +1,299 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::{Parser, ValueEnum};
+
+use candle_transformers::models::mamba::{Config, Model, State};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+ model: Model,
+ config: Config,
+ device: Device,
+ tokenizer: TokenOutputStream,
+ logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
+impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ config: Config,
+ tokenizer: Tokenizer,
+ seed: u64,
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ device: &Device,
+ ) -> Self {
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ Self {
+ model,
+ config,
+ tokenizer: TokenOutputStream::new(tokenizer),
+ logits_processor,
+ repeat_penalty,
+ repeat_last_n,
+ device: device.clone(),
+ }
+ }
+
+ fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+ use std::io::Write;
+ self.tokenizer.clear();
+ let mut tokens = self
+ .tokenizer
+ .tokenizer()
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let mut generated_tokens = 0usize;
+ let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+ Some(token) => token,
+ None => anyhow::bail!("cannot find the </s> token"),
+ };
+ let mut state = State::new(1, &self.config, &self.device)?;
+ let mut next_logits = None;
+ for &t in tokens.iter() {
+ let input = Tensor::new(&[t], &self.device)?;
+ let logits = self.model.forward(&input, &mut state)?;
+ next_logits = Some(logits);
+ if let Some(t) = self.tokenizer.next_token(t)? {
+ print!("{t}")
+ }
+ }
+ std::io::stdout().flush()?;
+
+ let start_gen = std::time::Instant::now();
+ for _ in 0..sample_len {
+ let logits = match next_logits.as_ref() {
+ Some(logits) => logits,
+ None => anyhow::bail!("cannot work on an empty prompt"),
+ };
+ let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == eos_token {
+ break;
+ }
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+
+ let input = Tensor::new(&[next_token], &self.device)?;
+ next_logits = Some(self.model.forward(&input, &mut state)?)
+ }
+ let dt = start_gen.elapsed();
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+ }
+}
+
+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+ Mamba130m,
+ Mamba370m,
+ Mamba790m,
+ Mamba1_4b,
+ Mamba2_8b,
+ Mamba2_8bSlimPj,
+}
+
+impl std::fmt::Display for Which {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl Which {
+ fn model_id(&self) -> &'static str {
+ match self {
+ Self::Mamba130m => "state-spaces/mamba-130m",
+ Self::Mamba370m => "state-spaces/mamba-370m",
+ Self::Mamba790m => "state-spaces/mamba-790m",
+ Self::Mamba1_4b => "state-spaces/mamba-1.4b",
+ Self::Mamba2_8b => "state-spaces/mamba-2.8b",
+ Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
+ }
+ }
+
+ fn revision(&self) -> &'static str {
+ match self {
+ Self::Mamba130m
+ | Self::Mamba370m
+ | Self::Mamba790m
+ | Self::Mamba1_4b
+ | Self::Mamba2_8bSlimPj => "refs/pr/1",
+ Self::Mamba2_8b => "refs/pr/4",
+ }
+ }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long)]
+ temperature: Option<f64>,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, short = 'n', default_value_t = 5000)]
+ sample_len: usize,
+
+ #[arg(long, default_value = "mamba130m")]
+ which: Which,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature.unwrap_or(0.),
+ args.repeat_penalty,
+ args.repeat_last_n
+ );
+
+ let start = std::time::Instant::now();
+ let api = Api::new()?;
+ let repo = api.repo(Repo::with_revision(
+ args.model_id
+ .unwrap_or_else(|| args.which.model_id().to_string()),
+ RepoType::Model,
+ args.revision
+ .unwrap_or_else(|| args.which.revision().to_string()),
+ ));
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => api
+ .model("EleutherAI/gpt-neox-20b".to_string())
+ .get("tokenizer.json")?,
+ };
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+ let filenames = match args.weight_files {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => {
+ vec![repo.get("model.safetensors")?]
+ }
+ };
+ println!("retrieved the files in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let start = std::time::Instant::now();
+ let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+ let device = candle_examples::device(args.cpu)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ let model = Model::new(&config, vb.pp("backbone"))?;
+ println!("loaded the model in {:?}", start.elapsed());
+
+ let mut pipeline = TextGeneration::new(
+ model,
+ config,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ &device,
+ );
+ pipeline.run(&args.prompt, args.sample_len)?;
+ Ok(())
+}
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
new file mode 100644
index 00000000..da254bd1
--- /dev/null
+++ b/candle-transformers/src/models/mamba.rs
@@ -0,0 +1,211 @@
+#![allow(unused)]
+/// A fast implementation of mamba for inference only.
+/// This is based on: https://github.com/LaurentMazare/mamba.rs
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{RmsNorm, VarBuilder};
+
+const D_CONV: usize = 4;
+const D_STATE: usize = 16;
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ d_model: usize,
+ n_layer: usize,
+ vocab_size: usize,
+ pad_vocab_size_multiple: usize,
+}
+
+impl Config {
+ fn vocab_size(&self) -> usize {
+ let pad = self.pad_vocab_size_multiple;
+ (self.vocab_size + pad - 1) / pad * pad
+ }
+
+ fn dt_rank(&self) -> usize {
+ (self.d_model + 15) / 16
+ }
+
+ fn d_inner(&self) -> usize {
+ self.d_model * 2
+ }
+}
+
+pub struct State {
+ hs: Vec<Tensor>,
+ prev_xs: Vec<[Tensor; D_CONV]>,
+ pos: usize,
+}
+
+impl State {
+ pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+ let mut hs = Vec::with_capacity(cfg.n_layer);
+ let mut prev_xs = Vec::with_capacity(cfg.n_layer);
+ for _i in 0..cfg.n_layer {
+ let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
+ let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+ hs.push(h);
+ prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
+ }
+ Ok(Self {
+ hs,
+ prev_xs,
+ pos: 0,
+ })
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct MambaBlock {
+ in_proj: Linear,
+ conv1d_bias: Tensor,
+ conv1d_weights: [Tensor; D_CONV],
+ x_proj: Linear,
+ dt_proj: Linear,
+ a_log: Tensor,
+ d: Tensor,
+ out_proj: Linear,
+ dt_rank: usize,
+ layer_index: usize,
+ d_inner: usize,
+}
+
+impl MambaBlock {
+ pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let d_inner = cfg.d_inner();
+ let dt_rank = cfg.dt_rank();
+ let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
+ let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?;
+ let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
+ let a_log = vb.get((d_inner, D_STATE), "A_log")?;
+ let d = vb.get(d_inner, "D")?;
+ let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
+ let conv1d_bias = vb.get(d_inner, "conv1d.bias")?;
+ let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?;
+ let conv1d_weights = [
+ conv1d_weight.i((.., 0, 0))?,
+ conv1d_weight.i((.., 0, 1))?,
+ conv1d_weight.i((.., 0, 2))?,
+ conv1d_weight.i((.., 0, 3))?,
+ ];
+ Ok(Self {
+ in_proj,
+ conv1d_bias,
+ conv1d_weights,
+ x_proj,
+ dt_proj,
+ a_log,
+ d,
+ out_proj,
+ dt_rank,
+ layer_index,
+ d_inner,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let (b_sz, _dim) = xs.dims2()?;
+ let li = self.layer_index;
+ let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
+ let proj_for_silu = xs.remove(1);
+ state.prev_xs[li][state.pos % D_CONV] = xs.remove(0);
+ let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?;
+ for d_c in 0..D_CONV {
+ proj_for_conv = (proj_for_conv
+ + self.conv1d_weights[d_c]
+ .broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?;
+ }
+ let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?;
+ // SSM + Selection, we're doing inference here so only need the last step of
+ // the sequence.
+ // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
+
+ let x_proj = self.x_proj.forward(&proj_for_conv)?;
+ let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
+ let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
+ let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;
+
+ let delta = delta.apply(&self.dt_proj)?;
+ // softplus
+ let delta = (delta.exp()? + 1.)?.log()?;
+ let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
+ let d = self.d.to_dtype(candle::DType::F32)?;
+
+ // Selective scan part
+ // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
+ let delta = delta
+ .unsqueeze(D::Minus1)?
+ .broadcast_as((b_sz, self.d_inner, D_STATE))?;
+ let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?;
+ let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?;
+ let proj_for_conv_b =
+ proj_for_conv
+ .unsqueeze(D::Minus1)?
+ .broadcast_as((b_sz, self.d_inner, D_STATE))?;
+ state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?;
+ let ss = (state.hs[li]
+ .matmul(&c.unsqueeze(D::Minus1)?)?
+ .squeeze(D::Minus1)?
+ + proj_for_conv.broadcast_mul(&d)?)?;
+
+ let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?;
+ ys.apply(&self.out_proj)
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ResidualBlock {
+ mixer: MambaBlock,
+ norm: RmsNorm,
+}
+
+impl ResidualBlock {
+ pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
+ let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?;
+ Ok(Self { mixer, norm })
+ }
+
+ fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs
+ }
+}
+
+// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
+#[derive(Clone, Debug)]
+pub struct Model {
+ embedding: candle_nn::Embedding,
+ layers: Vec<ResidualBlock>,
+ norm_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
+ let mut layers = Vec::with_capacity(cfg.n_layer);
+ let vb_l = vb.pp("layers");
+ for layer_idx in 0..cfg.n_layer {
+ let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+ let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
+ Ok(Self {
+ embedding,
+ layers,
+ norm_f,
+ lm_head,
+ })
+ }
+
+ pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result<Tensor> {
+ let _b_size = input_ids.dims1()?;
+ let mut xs = self.embedding.forward(input_ids)?;
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, state)?
+ }
+ state.pos += 1;
+ xs.apply(&self.norm_f)?.apply(&self.lm_head)
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index f3782fff..769fd650 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -13,6 +13,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
+pub mod mamba;
pub mod marian;
pub mod mistral;
pub mod mixformer;