summaryrefslogtreecommitdiff
path: root/candle-examples/examples/jina-bert/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-01 13:19:41 +0100
committerGitHub <noreply@github.com>2024-08-01 14:19:41 +0200
commit9ca277a9d71d1919228a4b994750e3d811da6b0a (patch)
treee875f07bd7edeefdf970b8e204e684dd4bd2d081 /candle-examples/examples/jina-bert/main.rs
parent2e9c010609c0516c85f35f829ea4f0e29820100f (diff)
downloadcandle-9ca277a9d71d1919228a4b994750e3d811da6b0a.tar.gz
candle-9ca277a9d71d1919228a4b994750e3d811da6b0a.tar.bz2
candle-9ca277a9d71d1919228a4b994750e3d811da6b0a.zip
Fix cargo fmt. (#2383)
* Fix cargo fmt. * Clippy fix. * Cosmetic tweaks.
Diffstat (limited to 'candle-examples/examples/jina-bert/main.rs')
-rw-r--r--candle-examples/examples/jina-bert/main.rs33
1 files changed, 19 insertions, 14 deletions
diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs
index 04b0c2d5..4a969a5c 100644
--- a/candle-examples/examples/jina-bert/main.rs
+++ b/candle-examples/examples/jina-bert/main.rs
@@ -47,33 +47,39 @@ struct Args {
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
- let default = "jinaai/jina-embeddings-v2-base-en".to_string();
- let model_name = match &self.model {
- Some(model) => model,
- None => &default,
+ let model_name = match self.model.as_ref() {
+ Some(model) => model.to_string(),
+ None => "jinaai/jina-embeddings-v2-base-en".to_string(),
};
let model = match &self.model_file {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
- .repo(Repo::new(
- model_name.to_string(),
- RepoType::Model,
- ))
+ .repo(Repo::new(model_name.to_string(), RepoType::Model))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
- .repo(Repo::new(
- model_name.to_string(),
- RepoType::Model,
- ))
+ .repo(Repo::new(model_name.to_string(), RepoType::Model))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi);
+ let config = Config::new(
+ tokenizer.get_vocab_size(true),
+ 768,
+ 12,
+ 12,
+ 3072,
+ candle_nn::Activation::Gelu,
+ 8192,
+ 2,
+ 0.02,
+ 1e-12,
+ 0,
+ PositionEmbeddingType::Alibi,
+ );
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
@@ -124,7 +130,6 @@ fn main() -> anyhow::Result<()> {
println!("normalized_embeddings: {embeddings}");
}
println!("Took {:?}", start.elapsed());
-
} else {
let sentences = [
"The cat sits outside",