summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/bert')
-rw-r--r--candle-examples/examples/bert/main.rs3
-rw-r--r--candle-examples/examples/bert/model.rs568
2 files changed, 1 insertions, 570 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6cee66ee..9d0eccdf 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-mod model;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
deleted file mode 100644
index 3f164a3a..00000000
--- a/candle-examples/examples/bert/model.rs
+++ /dev/null
@@ -1,568 +0,0 @@
-use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
-use serde::Deserialize;
-
-pub const DTYPE: DType = DType::F32;
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
-#[serde(rename_all = "lowercase")]
-enum HiddenAct {
- Gelu,
- Relu,
-}
-
-struct HiddenActLayer {
- act: HiddenAct,
- span: tracing::Span,
-}
-
-impl HiddenActLayer {
- fn new(act: HiddenAct) -> Self {
- let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
- Self { act, span }
- }
-
- fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
- let _enter = self.span.enter();
- match self.act {
- // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
- // small numerical difference.
- // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
- HiddenAct::Gelu => xs.gelu(),
- HiddenAct::Relu => xs.relu(),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct Linear {
- weight: Tensor,
- bias: Option<Tensor>,
- span: tracing::Span,
-}
-
-impl Linear {
- pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Self { weight, bias, span }
- }
-
- pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let _enter = self.span.enter();
- let w = match x.dims() {
- &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
- _ => self.weight.t()?,
- };
- let x = x.matmul(&w)?;
- match &self.bias {
- None => Ok(x),
- Some(bias) => x.broadcast_add(bias),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct LayerNorm {
- weight: Tensor,
- bias: Tensor,
- eps: f64,
- span: tracing::Span,
-}
-
-impl LayerNorm {
- pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
- let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
- Self {
- weight,
- bias,
- eps,
- span,
- }
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let x_dtype = x.dtype();
- let internal_dtype = match x_dtype {
- DType::F16 | DType::BF16 => DType::F32,
- d => d,
- };
- let (_bsize, _seq_len, hidden_size) = x.dims3()?;
- let x = x.to_dtype(internal_dtype)?;
- let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
- let x = x.broadcast_sub(&mean_x)?;
- let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
- let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
- let x = x_normed
- .to_dtype(x_dtype)?
- .broadcast_mul(&self.weight)?
- .broadcast_add(&self.bias)?;
- Ok(x)
- }
-}
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
-#[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
- #[default]
- Absolute,
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
- vocab_size: usize,
- hidden_size: usize,
- num_hidden_layers: usize,
- num_attention_heads: usize,
- intermediate_size: usize,
- hidden_act: HiddenAct,
- hidden_dropout_prob: f64,
- max_position_embeddings: usize,
- type_vocab_size: usize,
- initializer_range: f64,
- layer_norm_eps: f64,
- pad_token_id: usize,
- #[serde(default)]
- position_embedding_type: PositionEmbeddingType,
- #[serde(default)]
- use_cache: bool,
- classifier_dropout: Option<f64>,
- model_type: Option<String>,
-}
-
-impl Default for Config {
- fn default() -> Self {
- Self {
- vocab_size: 30522,
- hidden_size: 768,
- num_hidden_layers: 12,
- num_attention_heads: 12,
- intermediate_size: 3072,
- hidden_act: HiddenAct::Gelu,
- hidden_dropout_prob: 0.1,
- max_position_embeddings: 512,
- type_vocab_size: 2,
- initializer_range: 0.02,
- layer_norm_eps: 1e-12,
- pad_token_id: 0,
- position_embedding_type: PositionEmbeddingType::Absolute,
- use_cache: true,
- classifier_dropout: None,
- model_type: Some("bert".to_string()),
- }
- }
-}
-
-impl Config {
- fn _all_mini_lm_l6_v2() -> Self {
- // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
- Self {
- vocab_size: 30522,
- hidden_size: 384,
- num_hidden_layers: 6,
- num_attention_heads: 12,
- intermediate_size: 1536,
- hidden_act: HiddenAct::Gelu,
- hidden_dropout_prob: 0.1,
- max_position_embeddings: 512,
- type_vocab_size: 2,
- initializer_range: 0.02,
- layer_norm_eps: 1e-12,
- pad_token_id: 0,
- position_embedding_type: PositionEmbeddingType::Absolute,
- use_cache: true,
- classifier_dropout: None,
- model_type: Some("bert".to_string()),
- }
- }
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = vb.get(size2, "bias")?;
- Ok(Linear::new(weight, Some(bias)))
-}
-
-struct Dropout {
- #[allow(dead_code)]
- pr: f64,
-}
-
-impl Dropout {
- fn new(pr: f64) -> Self {
- Self { pr }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- // TODO
- Ok(x.clone())
- }
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
- let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
- (Ok(weight), Ok(bias)) => (weight, bias),
- (Err(err), _) | (_, Err(err)) => {
- if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
- (weight, bias)
- } else {
- return Err(err);
- }
- }
- };
- Ok(LayerNorm::new(weight, bias, eps))
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
-struct BertEmbeddings {
- word_embeddings: Embedding,
- position_embeddings: Option<Embedding>,
- token_type_embeddings: Embedding,
- layer_norm: LayerNorm,
- dropout: Dropout,
- span: tracing::Span,
-}
-
-impl BertEmbeddings {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let word_embeddings = embedding(
- config.vocab_size,
- config.hidden_size,
- vb.pp("word_embeddings"),
- )?;
- let position_embeddings = embedding(
- config.max_position_embeddings,
- config.hidden_size,
- vb.pp("position_embeddings"),
- )?;
- let token_type_embeddings = embedding(
- config.type_vocab_size,
- config.hidden_size,
- vb.pp("token_type_embeddings"),
- )?;
- let layer_norm = layer_norm(
- config.hidden_size,
- config.layer_norm_eps,
- vb.pp("LayerNorm"),
- )?;
- Ok(Self {
- word_embeddings,
- position_embeddings: Some(position_embeddings),
- token_type_embeddings,
- layer_norm,
- dropout: Dropout::new(config.hidden_dropout_prob),
- span: tracing::span!(tracing::Level::TRACE, "embeddings"),
- })
- }
-
- fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (_bsize, seq_len) = input_ids.dims2()?;
- let input_embeddings = self.word_embeddings.forward(input_ids)?;
- let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
- let mut embeddings = (&input_embeddings + token_type_embeddings)?;
- if let Some(position_embeddings) = &self.position_embeddings {
- // TODO: Proper absolute positions?
- let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
- let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
- embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
- }
- let embeddings = self.layer_norm.forward(&embeddings)?;
- let embeddings = self.dropout.forward(&embeddings)?;
- Ok(embeddings)
- }
-}
-
-struct BertSelfAttention {
- query: Linear,
- key: Linear,
- value: Linear,
- dropout: Dropout,
- num_attention_heads: usize,
- attention_head_size: usize,
- span: tracing::Span,
- span_softmax: tracing::Span,
-}
-
-impl BertSelfAttention {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let attention_head_size = config.hidden_size / config.num_attention_heads;
- let all_head_size = config.num_attention_heads * attention_head_size;
- let dropout = Dropout::new(config.hidden_dropout_prob);
- let hidden_size = config.hidden_size;
- let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
- let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
- let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
- Ok(Self {
- query,
- key,
- value,
- dropout,
- num_attention_heads: config.num_attention_heads,
- attention_head_size,
- span: tracing::span!(tracing::Level::TRACE, "self-attn"),
- span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
- })
- }
-
- fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
- let mut new_x_shape = xs.dims().to_vec();
- new_x_shape.pop();
- new_x_shape.push(self.num_attention_heads);
- new_x_shape.push(self.attention_head_size);
- let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
- xs.contiguous()
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let query_layer = self.query.forward(hidden_states)?;
- let key_layer = self.key.forward(hidden_states)?;
- let value_layer = self.value.forward(hidden_states)?;
-
- let query_layer = self.transpose_for_scores(&query_layer)?;
- let key_layer = self.transpose_for_scores(&key_layer)?;
- let value_layer = self.transpose_for_scores(&value_layer)?;
-
- let attention_scores = query_layer.matmul(&key_layer.t()?)?;
- let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
- let attention_probs = {
- let _enter_sm = self.span_softmax.enter();
- candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
- };
- let attention_probs = self.dropout.forward(&attention_probs)?;
-
- let context_layer = attention_probs.matmul(&value_layer)?;
- let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
- let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
- Ok(context_layer)
- }
-}
-
-struct BertSelfOutput {
- dense: Linear,
- layer_norm: LayerNorm,
- dropout: Dropout,
- span: tracing::Span,
-}
-
-impl BertSelfOutput {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
- let layer_norm = layer_norm(
- config.hidden_size,
- config.layer_norm_eps,
- vb.pp("LayerNorm"),
- )?;
- let dropout = Dropout::new(config.hidden_dropout_prob);
- Ok(Self {
- dense,
- layer_norm,
- dropout,
- span: tracing::span!(tracing::Level::TRACE, "self-out"),
- })
- }
-
- fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let hidden_states = self.dense.forward(hidden_states)?;
- let hidden_states = self.dropout.forward(&hidden_states)?;
- self.layer_norm.forward(&(hidden_states + input_tensor)?)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
-struct BertAttention {
- self_attention: BertSelfAttention,
- self_output: BertSelfOutput,
- span: tracing::Span,
-}
-
-impl BertAttention {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
- let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
- Ok(Self {
- self_attention,
- self_output,
- span: tracing::span!(tracing::Level::TRACE, "attn"),
- })
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let self_outputs = self.self_attention.forward(hidden_states)?;
- let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
- Ok(attention_output)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
-struct BertIntermediate {
- dense: Linear,
- intermediate_act: HiddenActLayer,
- span: tracing::Span,
-}
-
-impl BertIntermediate {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
- Ok(Self {
- dense,
- intermediate_act: HiddenActLayer::new(config.hidden_act),
- span: tracing::span!(tracing::Level::TRACE, "inter"),
- })
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let hidden_states = self.dense.forward(hidden_states)?;
- let ys = self.intermediate_act.forward(&hidden_states)?;
- Ok(ys)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
-struct BertOutput {
- dense: Linear,
- layer_norm: LayerNorm,
- dropout: Dropout,
- span: tracing::Span,
-}
-
-impl BertOutput {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
- let layer_norm = layer_norm(
- config.hidden_size,
- config.layer_norm_eps,
- vb.pp("LayerNorm"),
- )?;
- let dropout = Dropout::new(config.hidden_dropout_prob);
- Ok(Self {
- dense,
- layer_norm,
- dropout,
- span: tracing::span!(tracing::Level::TRACE, "out"),
- })
- }
-
- fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let hidden_states = self.dense.forward(hidden_states)?;
- let hidden_states = self.dropout.forward(&hidden_states)?;
- self.layer_norm.forward(&(hidden_states + input_tensor)?)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
- attention: BertAttention,
- intermediate: BertIntermediate,
- output: BertOutput,
- span: tracing::Span,
-}
-
-impl BertLayer {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let attention = BertAttention::load(vb.pp("attention"), config)?;
- let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
- let output = BertOutput::load(vb.pp("output"), config)?;
- Ok(Self {
- attention,
- intermediate,
- output,
- span: tracing::span!(tracing::Level::TRACE, "layer"),
- })
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let attention_output = self.attention.forward(hidden_states)?;
- // TODO: Support cross-attention?
- // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
- // TODO: Support something similar to `apply_chunking_to_forward`?
- let intermediate_output = self.intermediate.forward(&attention_output)?;
- let layer_output = self
- .output
- .forward(&intermediate_output, &attention_output)?;
- Ok(layer_output)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
- layers: Vec<BertLayer>,
- span: tracing::Span,
-}
-
-impl BertEncoder {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let layers = (0..config.num_hidden_layers)
- .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
- .collect::<Result<Vec<_>>>()?;
- let span = tracing::span!(tracing::Level::TRACE, "encoder");
- Ok(BertEncoder { layers, span })
- }
-
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut hidden_states = hidden_states.clone();
- // Use a loop rather than a fold as it's easier to modify when adding debug/...
- for layer in self.layers.iter() {
- hidden_states = layer.forward(&hidden_states)?
- }
- Ok(hidden_states)
- }
-}
-
-// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
-pub struct BertModel {
- embeddings: BertEmbeddings,
- encoder: BertEncoder,
- pub device: Device,
- span: tracing::Span,
-}
-
-impl BertModel {
- pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
- let (embeddings, encoder) = match (
- BertEmbeddings::load(vb.pp("embeddings"), config),
- BertEncoder::load(vb.pp("encoder"), config),
- ) {
- (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
- (Err(err), _) | (_, Err(err)) => {
- if let Some(model_type) = &config.model_type {
- if let (Ok(embeddings), Ok(encoder)) = (
- BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
- BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
- ) {
- (embeddings, encoder)
- } else {
- return Err(err);
- }
- } else {
- return Err(err);
- }
- }
- };
- Ok(Self {
- embeddings,
- encoder,
- device: vb.device().clone(),
- span: tracing::span!(tracing::Level::TRACE, "model"),
- })
- }
-
- pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
- let sequence_output = self.encoder.forward(&embedding_output)?;
- Ok(sequence_output)
- }
-}