diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-01 13:19:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 14:19:41 +0200 |
commit | 9ca277a9d71d1919228a4b994750e3d811da6b0a (patch) | |
tree | e875f07bd7edeefdf970b8e204e684dd4bd2d081 /candle-examples/examples/jina-bert/main.rs | |
parent | 2e9c010609c0516c85f35f829ea4f0e29820100f (diff) | |
download | candle-9ca277a9d71d1919228a4b994750e3d811da6b0a.tar.gz candle-9ca277a9d71d1919228a4b994750e3d811da6b0a.tar.bz2 candle-9ca277a9d71d1919228a4b994750e3d811da6b0a.zip |
Fix cargo fmt. (#2383)
* Fix cargo fmt.
* Clippy fix.
* Cosmetic tweaks.
Diffstat (limited to 'candle-examples/examples/jina-bert/main.rs')
-rw-r--r-- | candle-examples/examples/jina-bert/main.rs | 33 |
1 files changed, 19 insertions, 14 deletions
diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs index 04b0c2d5..4a969a5c 100644 --- a/candle-examples/examples/jina-bert/main.rs +++ b/candle-examples/examples/jina-bert/main.rs @@ -47,33 +47,39 @@ struct Args { impl Args { fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> { use hf_hub::{api::sync::Api, Repo, RepoType}; - let default = "jinaai/jina-embeddings-v2-base-en".to_string(); - let model_name = match &self.model { - Some(model) => model, - None => &default, + let model_name = match self.model.as_ref() { + Some(model) => model.to_string(), + None => "jinaai/jina-embeddings-v2-base-en".to_string(), }; let model = match &self.model_file { Some(model_file) => std::path::PathBuf::from(model_file), None => Api::new()? - .repo(Repo::new( - model_name.to_string(), - RepoType::Model, - )) + .repo(Repo::new(model_name.to_string(), RepoType::Model)) .get("model.safetensors")?, }; let tokenizer = match &self.tokenizer { Some(file) => std::path::PathBuf::from(file), None => Api::new()? - .repo(Repo::new( - model_name.to_string(), - RepoType::Model, - )) + .repo(Repo::new(model_name.to_string(), RepoType::Model)) .get("tokenizer.json")?, }; let device = candle_examples::device(self.cpu)?; let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?; - let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi); + let config = Config::new( + tokenizer.get_vocab_size(true), + 768, + 12, + 12, + 3072, + candle_nn::Activation::Gelu, + 8192, + 2, + 0.02, + 1e-12, + 0, + PositionEmbeddingType::Alibi, + ); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; let model = BertModel::new(vb, &config)?; Ok((model, tokenizer)) @@ -124,7 +130,6 @@ fn main() -> anyhow::Result<()> { println!("normalized_embeddings: {embeddings}"); } println!("Took {:?}", start.elapsed()); - } else { let sentences = [ "The cat sits outside", |