summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-23 10:46:02 +0100
committerGitHub <noreply@github.com>2023-12-23 10:46:02 +0100
commitd8b9a727fc611e5690d71db3ca184d30cbd86dbc (patch)
tree33615fec2dd3a0ad5bfe6b492535b624f27375c2 /candle-examples
parentceb78d3e28977389d88f676ff24dd07fd602ae96 (diff)
downloadcandle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.tar.gz
candle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.tar.bz2
candle-d8b9a727fc611e5690d71db3ca184d30cbd86dbc.zip
Support different mamba models. (#1471)
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/mamba-minimal/main.rs59
1 files changed, 52 insertions, 7 deletions
diff --git a/candle-examples/examples/mamba-minimal/main.rs b/candle-examples/examples/mamba-minimal/main.rs
index 488027f7..c446bfd3 100644
--- a/candle-examples/examples/mamba-minimal/main.rs
+++ b/candle-examples/examples/mamba-minimal/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
mod model;
use model::{Config, Model};
@@ -111,6 +111,46 @@ impl TextGeneration {
}
}
+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+ Mamba130m,
+ Mamba370m,
+ Mamba790m,
+ Mamba1_4b,
+ Mamba2_8b,
+ Mamba2_8bSlimPj,
+}
+
+impl std::fmt::Display for Which {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl Which {
+ fn model_id(&self) -> &'static str {
+ match self {
+ Self::Mamba130m => "state-spaces/mamba-130m",
+ Self::Mamba370m => "state-spaces/mamba-370m",
+ Self::Mamba790m => "state-spaces/mamba-790m",
+ Self::Mamba1_4b => "state-spaces/mamba-1.4b",
+ Self::Mamba2_8b => "state-spaces/mamba-2.8b",
+ Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
+ }
+ }
+
+ fn revision(&self) -> &'static str {
+ match self {
+ Self::Mamba130m
+ | Self::Mamba370m
+ | Self::Mamba790m
+ | Self::Mamba1_4b
+ | Self::Mamba2_8b
+ | Self::Mamba2_8bSlimPj => "refs/pr/1",
+ }
+ }
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -141,11 +181,14 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
- #[arg(long, default_value = "state-spaces/mamba-130m")]
- model_id: String,
+ #[arg(long, default_value = "mamba130m")]
+ which: Which,
- #[arg(long, default_value = "refs/pr/1")]
- revision: String,
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
@@ -194,9 +237,11 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ args.model_id
+ .unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
- args.revision,
+ args.revision
+ .unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),