-rw-r--r--  candle-examples/examples/splade/README.md   28
-rw-r--r--  candle-examples/examples/splade/main.rs    210
-rw-r--r--  candle-transformers/src/models/bert.rs      97
3 files changed, 335 insertions, 0 deletions
diff --git a/candle-examples/examples/splade/README.md b/candle-examples/examples/splade/README.md
new file mode 100644
index 00000000..582cea27
--- /dev/null
+++ b/candle-examples/examples/splade/README.md
@@ -0,0 +1,28 @@
+# candle-splade
+
+SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations offer several advantages over dense approaches: efficient use of inverted indexes, explicit lexical matching, and interpretability. They also seem to generalize better on out-of-domain data. This example supports the following two tasks:
+
+- Compute sparse embedding for a given query.
+- Compute similarities between a set of sentences using sparse embeddings.
+
+## Sparse sentence embeddings
+
+SPLADE is used to compute the sparse embedding for a given query. The model weights
+are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model.
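+
+The sparse vector is obtained from the MLM logits: each vocabulary weight is the
+maximum over the sequence of `log(1 + relu(logit))`, and the result is L2-normalized.
+In the output below, the first line is the decoded set of non-zero expansion terms
+and the second line their corresponding weights.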
+
+```bash
+cargo run --example splade --release -- --prompt "Here is a test sentence"
+
+> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats"
+> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146]
+```
+
+When no prompt is provided, the example computes similarities between a fixed set of sentences and prints the top-scoring pairs:
+
+```bash
+cargo run --example splade --release
+
+> score: 0.47 'The new movie is awesome' 'The new movie is so great'
+> score: 0.43 'The cat sits outside' 'The cat plays in the garden'
+> score: 0.14 'I love pasta' 'Do you like pizza?'
+> score: 0.11 'A man is playing guitar' 'The cat plays in the garden'
+> score: 0.05 'A man is playing guitar' 'A woman watches TV'
+```
diff --git a/candle-examples/examples/splade/main.rs b/candle-examples/examples/splade/main.rs
new file mode 100644
index 00000000..aa4c60ac
--- /dev/null
+++ b/candle-examples/examples/splade/main.rs
@@ -0,0 +1,210 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::Tensor;
+use candle_nn::VarBuilder;
+use candle_transformers::models::bert::{self, BertForMaskedLM, Config};
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ /// Path to the tokenizer file.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ /// Path to the weight files.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ /// Path to the config file.
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => "prithivida/Splade_PP_en_v1".to_string(),
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+
+ let config_filename = match args.config_file {
+ Some(file) => PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+
+ let weights_filename = match args.weight_files {
+ Some(files) => PathBuf::from(files),
+ None => match repo.get("model.safetensors") {
+ Ok(safetensors) => safetensors,
+ Err(_) => match repo.get("pytorch_model.bin") {
+ Ok(pytorch_model) => pytorch_model,
+ Err(e) => {
+ return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
+ }
+ },
+ },
+ };
+
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let device = candle_examples::device(args.cpu)?;
+ let dtype = bert::DTYPE;
+
+ // Check the extension rather than using `Path::ends_with`, which only matches
+ // whole path components and would miss user-provided safetensors files.
+ let vb = if weights_filename.extension().map_or(false, |ext| ext == "safetensors") {
+ unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device)? }
+ } else {
+ println!("Loading weights from pytorch_model.bin");
+ VarBuilder::from_pth(&weights_filename, dtype, &device)?
+ };
+ let model = BertForMaskedLM::load(vb, &config)?;
+
+ if let Some(prompt) = args.prompt {
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
+ let tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+
+ let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
+ let token_type_ids = token_ids.zeros_like()?;
+
+ let ys = model.forward(&token_ids, &token_type_ids, None)?;
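+ // SPLADE activation: log(1 + relu(logits)), max-pooled over the sequence
+ // dimension so each vocabulary entry keeps its strongest activation.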
+ let vec = Tensor::log(
+ &Tensor::try_from(1.0)?
+ .to_dtype(dtype)?
+ .to_device(&device)?
+ .broadcast_add(&ys.relu()?)?,
+ )?
+ .max(1)?;
+ let vec = normalize_l2(&vec)?;
+
+ let vec = vec.squeeze(0)?.to_vec1::<f32>()?;
+
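+ // The non-zero dimensions correspond to expansion terms in the vocabulary.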
+ let indices = (0..vec.len())
+ .filter(|&i| vec[i] != 0.0)
+ .map(|x| x as u32)
+ .collect::<Vec<_>>();
+
+ let tokens = tokenizer.decode(&indices, true).map_err(E::msg)?;
+ println!("{tokens:?}");
+ let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>();
+ println!("{values:?}");
+ } else {
+ let sentences = [
+ "The cat sits outside",
+ "A man is playing guitar",
+ "I love pasta",
+ "The new movie is awesome",
+ "The cat plays in the garden",
+ "A woman watches TV",
+ "The new movie is so great",
+ "Do you like pizza?",
+ ];
+
+ let n_sentences = sentences.len();
+ if let Some(pp) = tokenizer.get_padding_mut() {
+ pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+ } else {
+ let pp = PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ ..Default::default()
+ };
+ tokenizer.with_padding(Some(pp));
+ }
+ let tokens = tokenizer
+ .encode_batch(sentences.to_vec(), true)
+ .map_err(E::msg)?;
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Ok(Tensor::new(tokens.as_slice(), &device)?)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Ok(Tensor::new(tokens.as_slice(), &device)?)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let token_ids = Tensor::stack(&token_ids, 0)?;
+ let attention_mask = Tensor::stack(&attention_mask, 0)?;
+ let token_type_ids = token_ids.zeros_like()?;
+
+ let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
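+ // Same SPLADE activation as in the single-prompt path, with padding
+ // positions masked out before the max-pool over the sequence dimension.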
+ let vector = Tensor::log(
+ &Tensor::try_from(1.0)?
+ .to_dtype(dtype)?
+ .to_device(&device)?
+ .broadcast_add(&ys.relu()?)?,
+ )?;
+ let vector = vector
+ .broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)?
+ .max(1)?;
+ let vec = normalize_l2(&vector)?;
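+ // Pairwise cosine similarity between the (already L2-normalized) sparse embeddings.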
+ let mut similarities = vec![];
+ for i in 0..n_sentences {
+ let e_i = vec.get(i)?;
+ for j in (i + 1)..n_sentences {
+ let e_j = vec.get(j)?;
+ let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
+ let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
+ let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+ similarities.push((cosine_similarity, i, j))
+ }
+ }
+ similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+ for &(score, i, j) in similarities[..5].iter() {
+ println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+ }
+ }
+
+ Ok(())
+}
+
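+/// L2-normalizes each row of the tensor along dimension 1.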
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 354048de..bdc0385d 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
+struct BertPredictionHeadTransform {
+ dense: Linear,
+ activation: HiddenActLayer,
+ layer_norm: LayerNorm,
+}
+
+impl BertPredictionHeadTransform {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+ let activation = HiddenActLayer::new(config.hidden_act);
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ Ok(Self {
+ dense,
+ activation,
+ layer_norm,
+ })
+ }
+}
+
+impl Module for BertPredictionHeadTransform {
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let hidden_states = self
+ .activation
+ .forward(&self.dense.forward(hidden_states)?)?;
+ self.layer_norm.forward(&hidden_states)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
+pub struct BertLMPredictionHead {
+ transform: BertPredictionHeadTransform,
+ decoder: Linear,
+}
+
+impl BertLMPredictionHead {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
+ let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
+ Ok(Self { transform, decoder })
+ }
+}
+
+impl Module for BertLMPredictionHead {
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ self.decoder
+ .forward(&self.transform.forward(hidden_states)?)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
+pub struct BertOnlyMLMHead {
+ predictions: BertLMPredictionHead,
+}
+
+impl BertOnlyMLMHead {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
+ Ok(Self { predictions })
+ }
+}
+
+impl Module for BertOnlyMLMHead {
+ fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
+ self.predictions.forward(sequence_output)
+ }
+}
+
+pub struct BertForMaskedLM {
+ bert: BertModel,
+ cls: BertOnlyMLMHead,
+}
+
+impl BertForMaskedLM {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let bert = BertModel::load(vb.pp("bert"), config)?;
+ let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
+ Ok(Self { bert, cls })
+ }
+
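+ /// Returns the masked-language-modeling logits of shape `(batch_size, seq_len, vocab_size)`.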
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ token_type_ids: &Tensor,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let sequence_output = self
+ .bert
+ .forward(input_ids, token_type_ids, attention_mask)?;
+ self.cls.forward(&sequence_output)
+ }
+}