summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/moondream/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/moondream/main.rs')
-rw-r--r-- candle-examples/examples/moondream/main.rs | 28
1 files changed, 25 insertions, 3 deletions
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index 2ec04256..3e0f6d57 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -155,6 +155,18 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
+
+ #[arg(long, default_value = "vikhyatk/moondream2")]
+ model_id: String,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long)]
+ model_file: Option<String>,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
@@ -204,9 +216,19 @@ async fn main() -> anyhow::Result<()> {
let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
- let repo = api.model("vikhyatk/moondream2".to_string());
- let model_file = repo.get("model.safetensors").await?;
- let tokenizer = repo.get("tokenizer.json").await?;
+ let repo = api.repo(hf_hub::Repo::with_revision(
+ args.model_id,
+ hf_hub::RepoType::Model,
+ args.revision,
+ ));
+ let model_file = match args.model_file {
+ Some(m) => m.into(),
+ None => repo.get("model.safetensors").await?,
+ };
+ let tokenizer = match args.tokenizer_file {
+ Some(m) => m.into(),
+ None => repo.get("tokenizer.json").await?,
+ };
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;