diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-25 09:31:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-25 09:31:24 +0100 |
commit | 7135791dd5da6110e24057d82f0c2280cb732b4e (patch) | |
tree | 1b77539c1494ddebbc5d610b9248c514e5b6b97b /candle-examples/examples/mistral/main.rs | |
parent | 88589d88153bef3316a13741bd12bf5e7963957a (diff) | |
download | candle-7135791dd5da6110e24057d82f0c2280cb732b4e.tar.gz candle-7135791dd5da6110e24057d82f0c2280cb732b4e.tar.bz2 candle-7135791dd5da6110e24057d82f0c2280cb732b4e.zip |
Fix the quantized mistral example. (#1478)
Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r-- | candle-examples/examples/mistral/main.rs | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs index 55d08e6e..2b31142e 100644 --- a/candle-examples/examples/mistral/main.rs +++ b/candle-examples/examples/mistral/main.rs @@ -155,8 +155,8 @@ struct Args { #[arg(long, short = 'n', default_value_t = 100)] sample_len: usize, - #[arg(long, default_value = "mistralai/Mistral-7B-v0.1")] - model_id: String, + #[arg(long)] + model_id: Option<String>, #[arg(long, default_value = "main")] revision: String, @@ -207,8 +207,18 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let api = Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id, + None => { + if args.quantized { + "lmz/candle-mistral".to_string() + } else { + "mistralai/Mistral-7B-v0.1".to_string() + } + } + }; let repo = api.repo(Repo::with_revision( - args.model_id, + model_id, RepoType::Model, args.revision, )); |