summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mistral/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r--candle-examples/examples/mistral/main.rs20
1 files changed, 15 insertions, 5 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 18f18e5d..2b31142e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -155,8 +155,8 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
- #[arg(long, default_value = "lmz/candle-mistral")]
- model_id: String,
+ #[arg(long)]
+ model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
@@ -207,8 +207,18 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => {
+ if args.quantized {
+ "lmz/candle-mistral".to_string()
+ } else {
+ "mistralai/Mistral-7B-v0.1".to_string()
+ }
+ }
+ };
let repo = api.repo(Repo::with_revision(
- args.model_id,
+ model_id,
RepoType::Model,
args.revision,
));
@@ -226,8 +236,8 @@ fn main() -> Result<()> {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![
- repo.get("pytorch_model-00001-of-00002.safetensors")?,
- repo.get("pytorch_model-00002-of-00002.safetensors")?,
+ repo.get("model-00001-of-00002.safetensors")?,
+ repo.get("model-00002-of-00002.safetensors")?,
]
}
}