summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mistral/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r--candle-examples/examples/mistral/main.rs16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 55d08e6e..2b31142e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -155,8 +155,8 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
- #[arg(long, default_value = "mistralai/Mistral-7B-v0.1")]
- model_id: String,
+ #[arg(long)]
+ model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
@@ -207,8 +207,18 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => {
+ if args.quantized {
+ "lmz/candle-mistral".to_string()
+ } else {
+ "mistralai/Mistral-7B-v0.1".to_string()
+ }
+ }
+ };
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ model_id,
RepoType::Model,
args.revision,
));