diff options
Diffstat (limited to 'candle-examples/examples/jina-bert/main.rs')
-rw-r--r-- | candle-examples/examples/jina-bert/main.rs | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs
index ffde777d..d959d4cb 100644
--- a/candle-examples/examples/jina-bert/main.rs
+++ b/candle-examples/examples/jina-bert/main.rs
@@ -35,19 +35,37 @@ struct Args {
     normalize_embeddings: bool,
 
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     #[arg(long)]
-    model: String,
+    model: Option<String>,
 }
 
 impl Args {
     fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
+        use hf_hub::{api::sync::Api, Repo, RepoType};
+        let model = match &self.model {
+            Some(model_file) => std::path::PathBuf::from(model_file),
+            None => Api::new()?
+                .repo(Repo::new(
+                    "jinaai/jina-embeddings-v2-base-en".to_string(),
+                    RepoType::Model,
+                ))
+                .get("model.safetensors")?,
+        };
+        let tokenizer = match &self.tokenizer {
+            Some(file) => std::path::PathBuf::from(file),
+            None => Api::new()?
+                .repo(Repo::new(
+                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
+                    RepoType::Model,
+                ))
+                .get("tokenizer.json")?,
+        };
         let device = candle_examples::device(self.cpu)?;
         let config = Config::v2_base();
-        let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
         let model = BertModel::new(vb, &config)?;
         Ok((model, tokenizer))
     }