path: root/candle-transformers/src/models/distilbert.rs
author     Odunayo <ogundepoodunayo@gmail.com>    2023-11-24 10:09:14 -0500
committer  GitHub <noreply@github.com>            2023-11-24 15:09:14 +0000
commit     762e996ce636c5168ec1dc3ded9bd729d4f14d84 (patch)
tree       ebf9173e8d7510249e087d85c0f960e12c949ca5 /candle-transformers/src/models/distilbert.rs
parent     ca19a9af6220366bb0beaea1eb7f34a1a8a2e07b (diff)
Distilbert (#1366)
* add bce with logit loss
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
* distilbert files
* Apply various cleanups.
* More cleanups.
* More polish.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers/src/models/distilbert.rs')
-rw-r--r--    candle-transformers/src/models/distilbert.rs    342
1 file changed, 342 insertions, 0 deletions
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs
new file mode 100644
index 00000000..ea074c97
--- /dev/null
+++ b/candle-transformers/src/models/distilbert.rs
@@ -0,0 +1,342 @@
+use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
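+/// Replaces the entries of `on_false` with the scalar `on_true` wherever `mask` is nonzero.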
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+struct HiddenActLayer {
+    act: HiddenAct,
+    span: tracing::Span,
+}
+
+impl HiddenActLayer {
+    fn new(act: HiddenAct) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+        Self { act, span }
+    }
+}
+
+impl Module for HiddenActLayer {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+        let _enter = self.span.enter();
+        match self.act {
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+            HiddenAct::Gelu => xs.gelu(),
+            HiddenAct::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+    #[default]
+    Absolute,
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    vocab_size: usize,
+    dim: usize,
+    n_layers: usize,
+    n_heads: usize,
+    hidden_dim: usize,
+    activation: HiddenAct,
+    max_position_embeddings: usize,
+    initializer_range: f64,
+    pad_token_id: usize,
+    #[serde(default)]
+    position_embedding_type: PositionEmbeddingType,
+    #[serde(default)]
+    use_cache: bool,
+    model_type: Option<String>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 30522,
+            dim: 768,
+            n_layers: 12,
+            n_heads: 12,
+            hidden_dim: 3072,
+            activation: HiddenAct::Gelu,
+            max_position_embeddings: 512,
+            initializer_range: 0.02,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            model_type: Some("distilbert".to_string()),
+        }
+    }
+}
+
+struct Embeddings {
+    word_embeddings: Embedding,
+    position_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl Embeddings {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings =
+            candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
+        let position_embeddings = candle_nn::embedding(
+            config.max_position_embeddings,
+            config.dim,
+            vb.pp("position_embeddings"),
+        )?;
+        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            word_embeddings,
+            position_embeddings,
+            layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_bsize, seq_len) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
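+        // Only absolute position embeddings are supported, so the position ids are just 0..seq_len.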
+        let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+        let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+        let embeddings =
+            input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
+
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+struct MultiHeadSelfAttention {
+    q_lin: Linear,
+    k_lin: Linear,
+    v_lin: Linear,
+    out_lin: Linear,
+    n_heads: usize,
+    attention_head_size: usize,
+    span: tracing::Span,
+}
+
+impl MultiHeadSelfAttention {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention_head_size = config.dim / config.n_heads;
+        let all_head_size = config.n_heads * attention_head_size;
+        let dim = config.dim;
+        let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
+        let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
+        let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
+        let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
+        Ok(Self {
+            q_lin,
+            k_lin,
+            v_lin,
+            out_lin,
+            n_heads: config.n_heads,
+            attention_head_size,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+}
+
+impl MultiHeadSelfAttention {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (bs, q_length, _dim) = hidden_states.dims3()?;
+
+        let dim_per_head = self.attention_head_size;
+        let q = self.q_lin.forward(hidden_states)?;
+        let k = self.k_lin.forward(hidden_states)?;
+        let v = self.v_lin.forward(hidden_states)?;
+
+        let q = q
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+        let k = k
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+        let v = v
+            .reshape((bs, q_length, self.n_heads, dim_per_head))?
+            .transpose(1, 2)?;
+
+        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
+        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
+        let mask = attention_mask.broadcast_as(scores.shape())?;
+
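+        // Nonzero entries of the broadcast mask are filled with -inf so they receive zero weight after the softmax.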
+        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
+        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
+
+        let context = weights.matmul(&v.contiguous()?)?;
+        let context = context
+            .transpose(1, 2)?
+            .reshape((bs, q_length, self.n_heads * dim_per_head))?
+            .contiguous()?;
+        let context = self.out_lin.forward(&context)?;
+
+        Ok(context)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+struct FFN {
+    lin1: Linear,
+    lin2: Linear,
+    activation: HiddenActLayer,
+    span: tracing::Span,
+}
+
+impl FFN {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
+        let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
+        Ok(Self {
+            lin1,
+            lin2,
+            activation: HiddenActLayer::new(config.activation),
+            span: tracing::span!(tracing::Level::TRACE, "ffn"),
+        })
+    }
+}
+
+impl Module for FFN {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        hidden_states
+            .apply(&self.lin1)?
+            .apply(&self.activation)?
+            .apply(&self.lin2)
+    }
+}
+
+struct TransformerBlock {
+    attention: MultiHeadSelfAttention,
+    sa_layer_norm: LayerNorm,
+    ffn: FFN,
+    output_layer_norm: LayerNorm,
+    span: tracing::Span,
+}
+
+impl TransformerBlock {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
+        let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
+        let ffn = FFN::load(vb.pp("ffn"), config)?;
+        let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
+        Ok(Self {
+            attention,
+            sa_layer_norm,
+            ffn,
+            output_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+}
+
+impl TransformerBlock {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let sa_output = sa_output.broadcast_add(hidden_states)?;
+        let sa_output = self.sa_layer_norm.forward(&sa_output)?;
+
+        let ffn_output = self.ffn.forward(&sa_output)?;
+        let ffn_output = (&ffn_output + sa_output)?;
+        let output = self.output_layer_norm.forward(&ffn_output)?;
+        Ok(output)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct Transformer {
+    layers: Vec<TransformerBlock>,
+    span: tracing::Span,
+}
+
+impl Transformer {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let layers = (0..config.n_layers)
+            .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
+            .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        Ok(Transformer { layers, span })
+    }
+}
+
+impl Transformer {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states, attention_mask)?;
+        }
+        Ok(hidden_states)
+    }
+}
+
+pub struct DistilBertModel {
+    embeddings: Embeddings,
+    transformer: Transformer,
+    pub device: Device,
+    span: tracing::Span,
+}
+
+impl DistilBertModel {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
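+        // Some checkpoints prefix the weights with the model type (e.g. "distilbert.embeddings");
+        // fall back to that layout if the flat one cannot be loaded.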
+        let (embeddings, transformer) = match (
+            Embeddings::load(vb.pp("embeddings"), config),
+            Transformer::load(vb.pp("transformer"), config),
+        ) {
+            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+            (Err(err), _) | (_, Err(err)) => {
+                if let Some(model_type) = &config.model_type {
+                    if let (Ok(embeddings), Ok(encoder)) = (
+                        Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+                        Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
+                    ) {
+                        (embeddings, encoder)
+                    } else {
+                        return Err(err);
+                    }
+                } else {
+                    return Err(err);
+                }
+            }
+        };
+        Ok(Self {
+            embeddings,
+            transformer,
+            device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let embedding_output = self.embeddings.forward(input_ids)?;
+        let sequence_output = self
+            .transformer
+            .forward(&embedding_output, attention_mask)?;
+        Ok(sequence_output)
+    }
+}
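
A minimal usage sketch, not part of the commit above: it assumes a locally downloaded config.json and model.safetensors for a DistilBERT checkpoint, hand-picked token ids standing in for tokenizer output, and an all-zero attention mask, which under the masked_fill convention above (nonzero entries become -inf) means no position is masked out. The `candle` path follows this repository's renaming of the candle-core crate; serde_json and anyhow are assumed as dependencies.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Assumed local copies of the checkpoint's config.json and model.safetensors.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DTYPE, &device)? };
    let model = DistilBertModel::load(vb, &config)?;

    // Hypothetical token ids for a single short sequence (batch size 1); in practice they
    // would come from the checkpoint's tokenizer.
    let token_ids: Vec<u32> = vec![101, 7592, 2088, 102];
    let seq_len = token_ids.len();
    let input_ids = Tensor::new(token_ids.as_slice(), &device)?.unsqueeze(0)?;
    // An all-zero (seq_len, seq_len) mask broadcasts over batch and heads and masks nothing.
    let attention_mask = Tensor::zeros((seq_len, seq_len), DType::U8, &device)?;

    let hidden_states = model.forward(&input_ids, &attention_mask)?;
    println!("{:?}", hidden_states.dims()); // [1, seq_len, dim]
    Ok(())
}

The forward pass returns the final hidden states of shape (batch, seq_len, dim); pooling or a task head on top is left to the caller.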