summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/jina-bert/main.rs40
-rw-r--r--candle-transformers/src/models/jina_bert.rs30
2 files changed, 58 insertions, 12 deletions
diff --git a/candle-examples/examples/jina-bert/main.rs b/candle-examples/examples/jina-bert/main.rs
index d959d4cb..04b0c2d5 100644
--- a/candle-examples/examples/jina-bert/main.rs
+++ b/candle-examples/examples/jina-bert/main.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_transformers::models::jina_bert::{BertModel, Config};
+use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
@@ -39,16 +39,25 @@ struct Args {
#[arg(long)]
model: Option<String>,
+
+ #[arg(long)]
+ model_file: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
- let model = match &self.model {
+ let default = "jinaai/jina-embeddings-v2-base-en".to_string();
+ let model_name = match &self.model {
+ Some(model) => model,
+ None => &default,
+ };
+
+ let model = match &self.model_file {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(
- "jinaai/jina-embeddings-v2-base-en".to_string(),
+ model_name.to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
@@ -57,14 +66,14 @@ impl Args {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(
- "sentence-transformers/all-MiniLM-L6-v2".to_string(),
+ model_name.to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
- let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let config = Config::new(tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
@@ -101,14 +110,21 @@ fn main() -> anyhow::Result<()> {
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
- for idx in 0..args.n {
- let start = std::time::Instant::now();
- let ys = model.forward(&token_ids)?;
- if idx == 0 {
- println!("{ys}");
- }
- println!("Took {:?}", start.elapsed());
+ let start = std::time::Instant::now();
+ let embeddings = model.forward(&token_ids)?;
+ let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+ let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+ println!("pooled_embeddigns: {embeddings}");
+ let embeddings = if args.normalize_embeddings {
+ normalize_l2(&embeddings)?
+ } else {
+ embeddings
+ };
+ if args.normalize_embeddings {
+ println!("normalized_embeddings: {embeddings}");
}
+ println!("Took {:?}", start.elapsed());
+
} else {
let sentences = [
"The cat sits outside",
diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs
index 7e3c3887..97bc1b25 100644
--- a/candle-transformers/src/models/jina_bert.rs
+++ b/candle-transformers/src/models/jina_bert.rs
@@ -47,6 +47,36 @@ impl Config {
position_embedding_type: PositionEmbeddingType::Alibi,
}
}
+
+ pub fn new(
+ vocab_size: usize,
+ hidden_size: usize,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ intermediate_size: usize,
+ hidden_act: candle_nn::Activation,
+ max_position_embeddings: usize,
+ type_vocab_size: usize,
+ initializer_range: f64,
+ layer_norm_eps: f64,
+ pad_token_id: usize,
+ position_embedding_type: PositionEmbeddingType,
+ ) -> Self {
+ Config {
+ vocab_size,
+ hidden_size,
+ num_hidden_layers,
+ num_attention_heads,
+ intermediate_size,
+ hidden_act,
+ max_position_embeddings,
+ type_vocab_size,
+ initializer_range,
+ layer_norm_eps,
+ pad_token_id,
+ position_embedding_type,
+ }
+ }
}
#[derive(Clone, Debug)]