summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mistral/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-25 09:31:24 +0100
committerGitHub <noreply@github.com>2023-12-25 09:31:24 +0100
commit7135791dd5da6110e24057d82f0c2280cb732b4e (patch)
tree1b77539c1494ddebbc5d610b9248c514e5b6b97b /candle-examples/examples/mistral/main.rs
parent88589d88153bef3316a13741bd12bf5e7963957a (diff)
downloadcandle-7135791dd5da6110e24057d82f0c2280cb732b4e.tar.gz
candle-7135791dd5da6110e24057d82f0c2280cb732b4e.tar.bz2
candle-7135791dd5da6110e24057d82f0c2280cb732b4e.zip
Fix the quantized mistral example. (#1478)
Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r--candle-examples/examples/mistral/main.rs16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 55d08e6e..2b31142e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -155,8 +155,8 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
- #[arg(long, default_value = "mistralai/Mistral-7B-v0.1")]
- model_id: String,
+ #[arg(long)]
+ model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
@@ -207,8 +207,18 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => {
+ if args.quantized {
+ "lmz/candle-mistral".to_string()
+ } else {
+ "mistralai/Mistral-7B-v0.1".to_string()
+ }
+ }
+ };
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ model_id,
RepoType::Model,
args.revision,
));