diff options
Diffstat (limited to 'candle-examples/examples/moondream/main.rs')
-rw-r--r-- | candle-examples/examples/moondream/main.rs | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs index 2ec04256..3e0f6d57 100644 --- a/candle-examples/examples/moondream/main.rs +++ b/candle-examples/examples/moondream/main.rs @@ -155,6 +155,18 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + #[arg(long, default_value = "vikhyatk/moondream2")] + model_id: String, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + model_file: Option<String>, + + #[arg(long)] + tokenizer_file: Option<String>, } /// Loads an image from disk using the image crate, this returns a tensor with shape @@ -204,9 +216,19 @@ async fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let api = hf_hub::api::tokio::Api::new()?; - let repo = api.model("vikhyatk/moondream2".to_string()); - let model_file = repo.get("model.safetensors").await?; - let tokenizer = repo.get("tokenizer.json").await?; + let repo = api.repo(hf_hub::Repo::with_revision( + args.model_id, + hf_hub::RepoType::Model, + args.revision, + )); + let model_file = match args.model_file { + Some(m) => m.into(), + None => repo.get("model.safetensors").await?, + }; + let tokenizer = match args.tokenizer_file { + Some(m) => m.into(), + None => repo.get("tokenizer.json").await?, + }; println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; |