diff options
-rw-r--r-- | candle-examples/examples/splade/README.md | 28 | ||||
-rw-r--r-- | candle-examples/examples/splade/main.rs | 210 | ||||
-rw-r--r-- | candle-transformers/src/models/bert.rs | 97 |
3 files changed, 335 insertions, 0 deletions
diff --git a/candle-examples/examples/splade/README.md b/candle-examples/examples/splade/README.md new file mode 100644 index 00000000..582cea27 --- /dev/null +++ b/candle-examples/examples/splade/README.md @@ -0,0 +1,28 @@ +# candle-splade + + SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations benefit from several advantages compared to dense approaches: efficient use of inverted index, explicit lexical match, interpretability... They also seem to be better at generalizing on out-of-domain data. In this example we can do the following two tasks: + +- Compute sparse embedding for a given query. +- Compute similarities between a set of sentences using sparse embeddings. + +## Sparse Sentence embeddings + +SPLADE is used to compute the sparse embedding for a given query. The model weights +are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model. + +```bash +cargo run --example splade --release -- --prompt "Here is a test sentence" + +> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats" +> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146] +``` + +```bash +cargo run --example splade --release --features + +> score: 0.47 'The new movie is awesome' 'The new movie is so great' +> score: 0.43 'The cat sits outside' 'The cat plays in the garden' +> score: 0.14 'I love pasta' 'Do you like pizza?' +> score: 0.11 'A man is playing guitar' 'The cat plays in the garden' +> score: 0.05 'A man is playing guitar' 'A woman watches TV' +``` diff --git a/candle-examples/examples/splade/main.rs b/candle-examples/examples/splade/main.rs new file mode 100644 index 00000000..aa4c60ac --- /dev/null +++ b/candle-examples/examples/splade/main.rs @@ -0,0 +1,210 @@ +use std::path::PathBuf; + +use anyhow::{Error as E, Result}; +use candle::Tensor; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{self, BertForMaskedLM, Config}; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending + #[arg(long)] + model_id: Option<String>, + + #[arg(long, default_value = "main")] + revision: String, + + // Path to the tokenizer file. + #[arg(long)] + tokenizer_file: Option<String>, + + // Path to the weight files. + #[arg(long)] + weight_files: Option<String>, + + // Path to the config file. + #[arg(long)] + config_file: Option<String>, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: Option<String>, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "prithivida/Splade_PP_en_v1".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + + let weights_filename = match args.weight_files { + Some(files) => PathBuf::from(files), + None => match repo.get("model.safetensors") { + Ok(safetensors) => safetensors, + Err(_) => match repo.get("pytorch_model.bin") { + Ok(pytorch_model) => pytorch_model, + Err(e) => { + return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e))); + } + }, + }, + }; + + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let device = candle_examples::device(args.cpu)?; + let dtype = bert::DTYPE; + + let vb = if weights_filename.ends_with("model.safetensors") { + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() } + } else { + println!("Loading weights from pytorch_model.bin"); + VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap() + }; + let model = BertForMaskedLM::load(vb, &config)?; + + if let Some(prompt) = args.prompt { + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; + let tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + + let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + let token_type_ids = token_ids.zeros_like()?; + + let ys = model.forward(&token_ids, &token_type_ids, None)?; + let vec = Tensor::log( + &Tensor::try_from(1.0)? + .to_dtype(dtype)? + .to_device(&device)? + .broadcast_add(&ys.relu()?)?, + )? + .max(1)?; + let vec = normalize_l2(&vec)?; + + let vec = vec.squeeze(0)?.to_vec1::<f32>()?; + + let indices = (0..vec.len()) + .filter(|&i| vec[i] != 0.0) + .map(|x| x as u32) + .collect::<Vec<_>>(); + + let tokens = tokenizer.decode(&indices, true).unwrap(); + println!("{tokens:?}"); + let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>(); + println!("{values:?}"); + } else { + let sentences = [ + "The cat sits outside", + "A man is playing guitar", + "I love pasta", + "The new movie is awesome", + "The cat plays in the garden", + "A woman watches TV", + "The new movie is so great", + "Do you like pizza?", + ]; + + let n_sentences = sentences.len(); + if let Some(pp) = tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + } + let tokens = tokenizer + .encode_batch(sentences.to_vec(), true) + .map_err(E::msg)?; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Ok(Tensor::new(tokens.as_slice(), &device)?) + }) + .collect::<Result<Vec<_>>>()?; + let attention_mask = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_attention_mask().to_vec(); + Ok(Tensor::new(tokens.as_slice(), &device)?) + }) + .collect::<Result<Vec<_>>>()?; + + let token_ids = Tensor::stack(&token_ids, 0)?; + let attention_mask = Tensor::stack(&attention_mask, 0)?; + let token_type_ids = token_ids.zeros_like()?; + + let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?; + let vector = Tensor::log( + &Tensor::try_from(1.0)? + .to_dtype(dtype)? + .to_device(&device)? + .broadcast_add(&ys.relu()?)?, + )?; + let vector = vector + .broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)? + .max(1)?; + let vec = normalize_l2(&vector)?; + let mut similarities = vec![]; + for i in 0..n_sentences { + let e_i = vec.get(i)?; + for j in (i + 1)..n_sentences { + let e_j = vec.get(j)?; + let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?; + let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?; + let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?; + let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); + similarities.push((cosine_similarity, i, j)) + } + } + similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); + for &(score, i, j) in similarities[..5].iter() { + println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) + } + } + + Ok(()) +} + +pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { + Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) +} diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 354048de..bdc0385d 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result< (attention_mask.ones_like()? - &attention_mask)? .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) } + +//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766 +struct BertPredictionHeadTransform { + dense: Linear, + activation: HiddenActLayer, + layer_norm: LayerNorm, +} + +impl BertPredictionHeadTransform { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let activation = HiddenActLayer::new(config.hidden_act); + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + dense, + activation, + layer_norm, + }) + } +} + +impl Module for BertPredictionHeadTransform { + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let hidden_states = self + .activation + .forward(&self.dense.forward(hidden_states)?)?; + self.layer_norm.forward(&hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1 +pub struct BertLMPredictionHead { + transform: BertPredictionHeadTransform, + decoder: Linear, +} + +impl BertLMPredictionHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?; + let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?; + Ok(Self { transform, decoder }) + } +} + +impl Module for BertLMPredictionHead { + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + self.decoder + .forward(&self.transform.forward(hidden_states)?) + } +} + +// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792 +pub struct BertOnlyMLMHead { + predictions: BertLMPredictionHead, +} + +impl BertOnlyMLMHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?; + Ok(Self { predictions }) + } +} + +impl Module for BertOnlyMLMHead { + fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> { + self.predictions.forward(sequence_output) + } +} + +pub struct BertForMaskedLM { + bert: BertModel, + cls: BertOnlyMLMHead, +} + +impl BertForMaskedLM { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let bert = BertModel::load(vb.pp("bert"), config)?; + let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?; + Ok(Self { bert, cls }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: &Tensor, + attention_mask: Option<&Tensor>, + ) -> Result<Tensor> { + let sequence_output = self + .bert + .forward(input_ids, token_type_ids, attention_mask)?; + self.cls.forward(&sequence_output) + } +} |