 README.md                                      |  3 +++
 candle-examples/examples/mamba-minimal/main.rs | 59 ++++++++++++++++++++---
 2 files changed, 55 insertions(+), 7 deletions(-)
diff --git a/README.md b/README.md
index 26a81642..9f6cf9da 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,8 @@ We also provide some command-line based examples using state-of-the-art models:
 - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets.
+- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
+  implementation of the Mamba state space model.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
   better performance than all publicly available 13b models as of 2023-09-28.
 - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
@@ -177,6 +179,7 @@ If you have an addition to this list, please submit a pull request.
 - Falcon.
 - StarCoder.
 - Phi 1, 1.5, and 2.
+- Minimal Mamba.
 - Mistral 7b v0.1.
 - Mixtral 8x7b v0.1.
 - StableLM-3B-4E1T.
diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs
index 488027f7..c446bfd3 100644
--- a/candle-examples/examples/mamba-minimal/main.rs
+++ b/candle-examples/examples/mamba-minimal/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;

 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};

 mod model;
 use model::{Config, Model};
@@ -111,6 +111,46 @@ impl TextGeneration {
     }
 }

+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+    Mamba130m,
+    Mamba370m,
+    Mamba790m,
+    Mamba1_4b,
+    Mamba2_8b,
+    Mamba2_8bSlimPj,
+}
+
+impl std::fmt::Display for Which {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl Which {
+    fn model_id(&self) -> &'static str {
+        match self {
+            Self::Mamba130m => "state-spaces/mamba-130m",
+            Self::Mamba370m => "state-spaces/mamba-370m",
+            Self::Mamba790m => "state-spaces/mamba-790m",
+            Self::Mamba1_4b => "state-spaces/mamba-1.4b",
+            Self::Mamba2_8b => "state-spaces/mamba-2.8b",
+            Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
+        }
+    }
+
+    fn revision(&self) -> &'static str {
+        match self {
+            Self::Mamba130m
+            | Self::Mamba370m
+            | Self::Mamba790m
+            | Self::Mamba1_4b
+            | Self::Mamba2_8b
+            | Self::Mamba2_8bSlimPj => "refs/pr/1",
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
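As an aside for readers new to clap: the `ValueEnum` derive above is what lets the CLI accept kebab-cased variant names (`mamba130m`, `mamba1-4b`, `mamba2-8b-slim-pj`) and parse them into `Which`. Below is a minimal stand-alone sketch of the same pattern, assuming clap 4 with the `derive` feature enabled; it is illustrative only and not part of the patch.

use clap::{Parser, ValueEnum};

// Illustrative two-variant subset of the `Which` enum from the patch.
#[derive(ValueEnum, Clone, Copy, Debug)]
enum Which {
    Mamba130m,
    Mamba1_4b,
}

impl Which {
    // Map each preset to its Hugging Face Hub repo id.
    fn model_id(&self) -> &'static str {
        match self {
            Self::Mamba130m => "state-spaces/mamba-130m",
            Self::Mamba1_4b => "state-spaces/mamba-1.4b",
        }
    }
}

#[derive(Parser, Debug)]
struct Args {
    // clap kebab-cases variant names, so `--which mamba1-4b` selects Mamba1_4b;
    // the string default is parsed through the same ValueEnum machinery.
    #[arg(long, default_value = "mamba130m")]
    which: Which,
}

fn main() {
    let args = Args::parse();
    println!("resolved model id: {}", args.which.model_id());
}

Running this with `--which mamba1-4b` prints `state-spaces/mamba-1.4b`; with no flag it falls back to the `mamba130m` default.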
@@ -141,11 +181,14 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 5000)]
     sample_len: usize,

-    #[arg(long, default_value = "state-spaces/mamba-130m")]
-    model_id: String,
+    #[arg(long, default_value = "mamba130m")]
+    which: Which,

-    #[arg(long, default_value = "refs/pr/1")]
-    revision: String,
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,

     #[arg(long)]
     tokenizer_file: Option<String>,
@@ -194,9 +237,11 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        args.model_id
+            .unwrap_or_else(|| args.which.model_id().to_string()),
         RepoType::Model,
-        args.revision,
+        args.revision
+            .unwrap_or_else(|| args.which.revision().to_string()),
     ));
     let tokenizer_filename = match args.tokenizer_file {
         Some(file) => std::path::PathBuf::from(file),
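The resolution logic in `main` thus reduces to a simple precedence rule via `Option::unwrap_or_else`: an explicit `--model-id` or `--revision` flag wins, otherwise the preset carried by `--which` applies. A small sketch of that rule in isolation follows; the `resolve` helper and the `my-org/my-mamba` id are hypothetical, not part of the patch.

// Hypothetical helper showing the override-or-preset precedence used above.
fn resolve(explicit: Option<String>, preset: &str) -> String {
    explicit.unwrap_or_else(|| preset.to_string())
}

fn main() {
    // No --model-id flag: fall back to the preset for the selected `Which`.
    assert_eq!(
        resolve(None, "state-spaces/mamba-130m"),
        "state-spaces/mamba-130m"
    );
    // Explicit --model-id flag: the preset is ignored.
    assert_eq!(
        resolve(Some("my-org/my-mamba".into()), "state-spaces/mamba-130m"),
        "my-org/my-mamba"
    );
}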