summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/bert.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/bert.rs')
-rw-r--r--candle-transformers/src/models/bert.rs568
1 files changed, 568 insertions, 0 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
new file mode 100644
index 00000000..3f164a3a
--- /dev/null
+++ b/candle-transformers/src/models/bert.rs
@@ -0,0 +1,568 @@
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+ Gelu,
+ Relu,
+}
+
+struct HiddenActLayer {
+ act: HiddenAct,
+ span: tracing::Span,
+}
+
+impl HiddenActLayer {
+ fn new(act: HiddenAct) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+ Self { act, span }
+ }
+
+ fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+ let _enter = self.span.enter();
+ match self.act {
+ // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
+ // small numerical difference.
+ // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+ HiddenAct::Gelu => xs.gelu(),
+ HiddenAct::Relu => xs.relu(),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ span: tracing::Span,
+}
+
+impl Linear {
+ pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Self { weight, bias, span }
+ }
+
+ pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let _enter = self.span.enter();
+ let w = match x.dims() {
+ &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+ _ => self.weight.t()?,
+ };
+ let x = x.matmul(&w)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct LayerNorm {
+ weight: Tensor,
+ bias: Tensor,
+ eps: f64,
+ span: tracing::Span,
+}
+
+impl LayerNorm {
+ pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+ Self {
+ weight,
+ bias,
+ eps,
+ span,
+ }
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let x_dtype = x.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+ let x = x.to_dtype(internal_dtype)?;
+ let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ let x = x.broadcast_sub(&mean_x)?;
+ let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+ let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let x = x_normed
+ .to_dtype(x_dtype)?
+ .broadcast_mul(&self.weight)?
+ .broadcast_add(&self.bias)?;
+ Ok(x)
+ }
+}
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+ #[default]
+ Absolute,
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ hidden_size: usize,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ intermediate_size: usize,
+ hidden_act: HiddenAct,
+ hidden_dropout_prob: f64,
+ max_position_embeddings: usize,
+ type_vocab_size: usize,
+ initializer_range: f64,
+ layer_norm_eps: f64,
+ pad_token_id: usize,
+ #[serde(default)]
+ position_embedding_type: PositionEmbeddingType,
+ #[serde(default)]
+ use_cache: bool,
+ classifier_dropout: Option<f64>,
+ model_type: Option<String>,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 30522,
+ hidden_size: 768,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ intermediate_size: 3072,
+ hidden_act: HiddenAct::Gelu,
+ hidden_dropout_prob: 0.1,
+ max_position_embeddings: 512,
+ type_vocab_size: 2,
+ initializer_range: 0.02,
+ layer_norm_eps: 1e-12,
+ pad_token_id: 0,
+ position_embedding_type: PositionEmbeddingType::Absolute,
+ use_cache: true,
+ classifier_dropout: None,
+ model_type: Some("bert".to_string()),
+ }
+ }
+}
+
+impl Config {
+ fn _all_mini_lm_l6_v2() -> Self {
+ // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+ Self {
+ vocab_size: 30522,
+ hidden_size: 384,
+ num_hidden_layers: 6,
+ num_attention_heads: 12,
+ intermediate_size: 1536,
+ hidden_act: HiddenAct::Gelu,
+ hidden_dropout_prob: 0.1,
+ max_position_embeddings: 512,
+ type_vocab_size: 2,
+ initializer_range: 0.02,
+ layer_norm_eps: 1e-12,
+ pad_token_id: 0,
+ position_embedding_type: PositionEmbeddingType::Absolute,
+ use_cache: true,
+ classifier_dropout: None,
+ model_type: Some("bert".to_string()),
+ }
+ }
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = vb.get(size2, "bias")?;
+ Ok(Linear::new(weight, Some(bias)))
+}
+
+struct Dropout {
+ #[allow(dead_code)]
+ pr: f64,
+}
+
+impl Dropout {
+ fn new(pr: f64) -> Self {
+ Self { pr }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ // TODO
+ Ok(x.clone())
+ }
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+ let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+ (Ok(weight), Ok(bias)) => (weight, bias),
+ (Err(err), _) | (_, Err(err)) => {
+ if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+ (weight, bias)
+ } else {
+ return Err(err);
+ }
+ }
+ };
+ Ok(LayerNorm::new(weight, bias, eps))
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
+struct BertEmbeddings {
+ word_embeddings: Embedding,
+ position_embeddings: Option<Embedding>,
+ token_type_embeddings: Embedding,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertEmbeddings {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let word_embeddings = embedding(
+ config.vocab_size,
+ config.hidden_size,
+ vb.pp("word_embeddings"),
+ )?;
+ let position_embeddings = embedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ vb.pp("position_embeddings"),
+ )?;
+ let token_type_embeddings = embedding(
+ config.type_vocab_size,
+ config.hidden_size,
+ vb.pp("token_type_embeddings"),
+ )?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ Ok(Self {
+ word_embeddings,
+ position_embeddings: Some(position_embeddings),
+ token_type_embeddings,
+ layer_norm,
+ dropout: Dropout::new(config.hidden_dropout_prob),
+ span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+ })
+ }
+
+ fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_bsize, seq_len) = input_ids.dims2()?;
+ let input_embeddings = self.word_embeddings.forward(input_ids)?;
+ let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+ let mut embeddings = (&input_embeddings + token_type_embeddings)?;
+ if let Some(position_embeddings) = &self.position_embeddings {
+ // TODO: Proper absolute positions?
+ let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+ let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+ embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
+ }
+ let embeddings = self.layer_norm.forward(&embeddings)?;
+ let embeddings = self.dropout.forward(&embeddings)?;
+ Ok(embeddings)
+ }
+}
+
+struct BertSelfAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ dropout: Dropout,
+ num_attention_heads: usize,
+ attention_head_size: usize,
+ span: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl BertSelfAttention {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention_head_size = config.hidden_size / config.num_attention_heads;
+ let all_head_size = config.num_attention_heads * attention_head_size;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ let hidden_size = config.hidden_size;
+ let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+ let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+ let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+ Ok(Self {
+ query,
+ key,
+ value,
+ dropout,
+ num_attention_heads: config.num_attention_heads,
+ attention_head_size,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut new_x_shape = xs.dims().to_vec();
+ new_x_shape.pop();
+ new_x_shape.push(self.num_attention_heads);
+ new_x_shape.push(self.attention_head_size);
+ let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+ xs.contiguous()
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let query_layer = self.query.forward(hidden_states)?;
+ let key_layer = self.key.forward(hidden_states)?;
+ let value_layer = self.value.forward(hidden_states)?;
+
+ let query_layer = self.transpose_for_scores(&query_layer)?;
+ let key_layer = self.transpose_for_scores(&key_layer)?;
+ let value_layer = self.transpose_for_scores(&value_layer)?;
+
+ let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+ let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+ let attention_probs = {
+ let _enter_sm = self.span_softmax.enter();
+ candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
+ };
+ let attention_probs = self.dropout.forward(&attention_probs)?;
+
+ let context_layer = attention_probs.matmul(&value_layer)?;
+ let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+ let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
+ Ok(context_layer)
+ }
+}
+
+struct BertSelfOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertSelfOutput {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ Ok(Self {
+ dense,
+ layer_norm,
+ dropout,
+ span: tracing::span!(tracing::Level::TRACE, "self-out"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.dropout.forward(&hidden_states)?;
+ self.layer_norm.forward(&(hidden_states + input_tensor)?)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+struct BertAttention {
+ self_attention: BertSelfAttention,
+ self_output: BertSelfOutput,
+ span: tracing::Span,
+}
+
+impl BertAttention {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+ let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
+ Ok(Self {
+ self_attention,
+ self_output,
+ span: tracing::span!(tracing::Level::TRACE, "attn"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let self_outputs = self.self_attention.forward(hidden_states)?;
+ let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+ Ok(attention_output)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+struct BertIntermediate {
+ dense: Linear,
+ intermediate_act: HiddenActLayer,
+ span: tracing::Span,
+}
+
+impl BertIntermediate {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
+ Ok(Self {
+ dense,
+ intermediate_act: HiddenActLayer::new(config.hidden_act),
+ span: tracing::span!(tracing::Level::TRACE, "inter"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let ys = self.intermediate_act.forward(&hidden_states)?;
+ Ok(ys)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+struct BertOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertOutput {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ Ok(Self {
+ dense,
+ layer_norm,
+ dropout,
+ span: tracing::span!(tracing::Level::TRACE, "out"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.dropout.forward(&hidden_states)?;
+ self.layer_norm.forward(&(hidden_states + input_tensor)?)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+struct BertLayer {
+ attention: BertAttention,
+ intermediate: BertIntermediate,
+ output: BertOutput,
+ span: tracing::Span,
+}
+
+impl BertLayer {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention = BertAttention::load(vb.pp("attention"), config)?;
+ let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+ let output = BertOutput::load(vb.pp("output"), config)?;
+ Ok(Self {
+ attention,
+ intermediate,
+ output,
+ span: tracing::span!(tracing::Level::TRACE, "layer"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let attention_output = self.attention.forward(hidden_states)?;
+ // TODO: Support cross-attention?
+ // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+ // TODO: Support something similar to `apply_chunking_to_forward`?
+ let intermediate_output = self.intermediate.forward(&attention_output)?;
+ let layer_output = self
+ .output
+ .forward(&intermediate_output, &attention_output)?;
+ Ok(layer_output)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct BertEncoder {
+ layers: Vec<BertLayer>,
+ span: tracing::Span,
+}
+
+impl BertEncoder {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let layers = (0..config.num_hidden_layers)
+ .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
+ .collect::<Result<Vec<_>>>()?;
+ let span = tracing::span!(tracing::Level::TRACE, "encoder");
+ Ok(BertEncoder { layers, span })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut hidden_states = hidden_states.clone();
+ // Use a loop rather than a fold as it's easier to modify when adding debug/...
+ for layer in self.layers.iter() {
+ hidden_states = layer.forward(&hidden_states)?
+ }
+ Ok(hidden_states)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+pub struct BertModel {
+ embeddings: BertEmbeddings,
+ encoder: BertEncoder,
+ pub device: Device,
+ span: tracing::Span,
+}
+
+impl BertModel {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let (embeddings, encoder) = match (
+ BertEmbeddings::load(vb.pp("embeddings"), config),
+ BertEncoder::load(vb.pp("encoder"), config),
+ ) {
+ (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+ (Err(err), _) | (_, Err(err)) => {
+ if let Some(model_type) = &config.model_type {
+ if let (Ok(embeddings), Ok(encoder)) = (
+ BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+ BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
+ ) {
+ (embeddings, encoder)
+ } else {
+ return Err(err);
+ }
+ } else {
+ return Err(err);
+ }
+ }
+ };
+ Ok(Self {
+ embeddings,
+ encoder,
+ device: vb.device().clone(),
+ span: tracing::span!(tracing::Level::TRACE, "model"),
+ })
+ }
+
+ pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
+ let sequence_output = self.encoder.forward(&embedding_output)?;
+ Ok(sequence_output)
+ }
+}