summaryrefslogtreecommitdiff
path: root/candle-examples/examples/jina-bert/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/jina-bert/main.rs')
-rw-r--r--candle-examples/examples/jina-bert/main.rs28
1 files changed, 23 insertions, 5 deletions
diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs
index ffde777d..d959d4cb 100644
--- a/candle-examples/examples/jina-bert/main.rs
+++ b/candle-examples/examples/jina-bert/main.rs
@@ -35,19 +35,37 @@ struct Args {
normalize_embeddings: bool,
#[arg(long)]
- tokenizer: String,
+ tokenizer: Option<String>,
#[arg(long)]
- model: String,
+ model: Option<String>,
}
impl Args {
/// Builds the `BertModel` and its `tokenizers::Tokenizer` from the command-line arguments.
///
/// After this change both `--model` and `--tokenizer` are optional:
/// * when `self.model` is set, it is used as a path to the weights file; otherwise
///   `model.safetensors` is requested from the `jinaai/jina-embeddings-v2-base-en`
///   model repository through the `hf_hub` sync API.
/// * when `self.tokenizer` is set, it is used as a path to the tokenizer file; otherwise
///   `tokenizer.json` is requested from the `sentence-transformers/all-MiniLM-L6-v2`
///   model repository.
///
/// NOTE(review): the default tokenizer comes from a different repository than the
/// default weights — presumably intentional, confirm that the vocabularies match.
/// NOTE(review): the weights are loaded inside an `unsafe` block via
/// `VarBuilder::from_mmaped_safetensors`; presumably the file must not be modified
/// while it is mapped — confirm against the candle documentation.
///
/// # Errors
///
/// Propagates (via `?`) any failure from the hub API, device selection, weight
/// loading, or model construction; a tokenizer loading failure is converted
/// with `E::msg`.
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
+ use hf_hub::{api::sync::Api, Repo, RepoType};
+ let model = match &self.model {
+ Some(model_file) => std::path::PathBuf::from(model_file),
+ None => Api::new()?
+ .repo(Repo::new(
+ "jinaai/jina-embeddings-v2-base-en".to_string(),
+ RepoType::Model,
+ ))
+ .get("model.safetensors")?,
+ };
+ let tokenizer = match &self.tokenizer {
+ Some(file) => std::path::PathBuf::from(file),
+ None => Api::new()?
+ .repo(Repo::new(
+ "sentence-transformers/all-MiniLM-L6-v2".to_string(),
+ RepoType::Model,
+ ))
+ .get("tokenizer.json")?,
+ };
// Device selection is delegated to `candle_examples::device`, driven by the `cpu` flag.
let device = candle_examples::device(self.cpu)?;
// The model configuration is fixed to `Config::v2_base()`; it is not read from a file.
let config = Config::v2_base();
- let tokenizer = tokenizers::Tokenizer::from_file(&self.tokenizer).map_err(E::msg)?;
- let vb =
- unsafe { VarBuilder::from_mmaped_safetensors(&[&self.model], DType::F32, &device)? };
+ let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
// `model` is shadowed here: from this line on it is the constructed `BertModel`,
// built from the `VarBuilder` over the weights (loaded as `DType::F32`) and `config`.
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}