diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-23 10:46:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-23 10:46:02 +0100 |
commit | d8b9a727fc611e5690d71db3ca184d30cbd86dbc (patch) | |
tree | 33615fec2dd3a0ad5bfe6b492535b624f27375c2 /candle-examples | |
parent | ceb78d3e28977389d88f676ff24dd07fd602ae96 (diff) | |
download | candle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.tar.gz candle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.tar.bz2 candle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.zip |
Support different mamba models. (#1471)
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/mamba-minimal/main.rs | 59 |
1 files changed, 52 insertions, 7 deletions
diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs index 488027f7..c446bfd3 100644 --- a/candle-examples/examples/mamba-minimal/main.rs +++ b/candle-examples/examples/mamba-minimal/main.rs @@ -5,7 +5,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::{Error as E, Result}; -use clap::Parser; +use clap::{Parser, ValueEnum}; mod model; use model::{Config, Model}; @@ -111,6 +111,46 @@ impl TextGeneration { } } +#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mamba2_8bSlimPj, +} + +impl std::fmt::Display for Which { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn revision(&self) -> &'static str { + match self { + Self::Mamba130m + | Self::Mamba370m + | Self::Mamba790m + | Self::Mamba1_4b + | Self::Mamba2_8b + | Self::Mamba2_8bSlimPj => "refs/pr/1", + } + } +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -141,11 +181,14 @@ struct Args { #[arg(long, short = 'n', default_value_t = 5000)] sample_len: usize, - #[arg(long, default_value = "state-spaces/mamba-130m")] - model_id: String, + #[arg(long, default_value = "mamba130m")] + which: Which, - #[arg(long, default_value = "refs/pr/1")] - revision: String, + #[arg(long)] + model_id: Option<String>, + + #[arg(long)] + revision: Option<String>, #[arg(long)] tokenizer_file: Option<String>, @@ -194,9 +237,11 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let api = Api::new()?; let repo = api.repo(Repo::with_revision( - args.model_id, + args.model_id + .unwrap_or_else(|| args.which.model_id().to_string()), RepoType::Model, - args.revision, + args.revision + .unwrap_or_else(|| args.which.revision().to_string()), )); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), |