summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/bert.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/bert.rs')
-rw-r--r--candle-transformers/src/models/bert.rs97
1 files changed, 97 insertions, 0 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 354048de..bdc0385d 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
+/// Transform stage of the prediction head: a dense layer, an activation and
+/// a layer norm, applied in that order by `forward`.
+struct BertPredictionHeadTransform {
+ /// Linear layer built with `hidden_size` for both input and output width.
+ dense: Linear,
+ /// Activation constructed from `config.hidden_act`.
+ activation: HiddenActLayer,
+ /// Layer norm over `hidden_size` using `config.layer_norm_eps`.
+ layer_norm: LayerNorm,
+}
+
+impl BertPredictionHeadTransform {
+    /// Builds the dense layer, activation and layer norm from `vb` and `config`.
+    ///
+    /// The dense layer is created with `vb.pp("dense")` and the layer norm with
+    /// `vb.pp("LayerNorm")`; both use `config.hidden_size`. Returns an error if
+    /// either constructor fails.
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let hidden = config.hidden_size;
+        // Fields are evaluated in the order written: dense, activation, layer norm.
+        Ok(Self {
+            dense: linear(hidden, hidden, vb.pp("dense"))?,
+            activation: HiddenActLayer::new(config.hidden_act),
+            layer_norm: layer_norm(hidden, config.layer_norm_eps, vb.pp("LayerNorm"))?,
+        })
+    }
+}
+
+impl Module for BertPredictionHeadTransform {
+    /// Runs `hidden_states` through the dense layer, then the activation,
+    /// then the layer norm, propagating the first error encountered.
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let projected = self.dense.forward(hidden_states)?;
+        let activated = self.activation.forward(&projected)?;
+        self.layer_norm.forward(&activated)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
+/// Language-modelling prediction head: the transform stage followed by a
+/// linear decoder.
+pub struct BertLMPredictionHead {
+ /// Transform applied to the hidden states before decoding.
+ transform: BertPredictionHeadTransform,
+ /// Linear layer built with `hidden_size` inputs and `vocab_size` outputs.
+ decoder: Linear,
+}
+
+impl BertLMPredictionHead {
+    /// Builds the head from `vb` and `config`.
+    ///
+    /// The transform is loaded with `vb.pp("transform")` and the decoder is a
+    /// `linear(hidden_size, vocab_size, ..)` created with `vb.pp("decoder")`.
+    /// Returns an error if either sub-component fails to load.
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        Ok(Self {
+            transform: BertPredictionHeadTransform::load(vb.pp("transform"), config)?,
+            decoder: linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?,
+        })
+    }
+}
+
+impl Module for BertLMPredictionHead {
+    /// Applies the transform to `hidden_states`, then feeds the result to the
+    /// decoder.
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let transformed = self.transform.forward(hidden_states)?;
+        self.decoder.forward(&transformed)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
+/// Wrapper holding only the LM prediction head; `forward` delegates to it.
+pub struct BertOnlyMLMHead {
+ /// The wrapped prediction head, loaded with the `predictions` prefix.
+ predictions: BertLMPredictionHead,
+}
+
+impl BertOnlyMLMHead {
+    /// Loads the wrapped prediction head with `vb.pp("predictions")`.
+    ///
+    /// Returns whatever error `BertLMPredictionHead::load` reports.
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        BertLMPredictionHead::load(vb.pp("predictions"), config)
+            .map(|predictions| Self { predictions })
+    }
+}
+
+impl Module for BertOnlyMLMHead {
+    /// Forwards `sequence_output` to the wrapped prediction head unchanged.
+    fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
+        let Self { predictions } = self;
+        predictions.forward(sequence_output)
+    }
+}
+
+/// A `BertModel` paired with a masked-LM head that is applied to the model's
+/// output.
+pub struct BertForMaskedLM {
+ /// Base model, loaded with the `bert` prefix.
+ bert: BertModel,
+ /// Masked-LM head, loaded with the `cls` prefix.
+ cls: BertOnlyMLMHead,
+}
+
+impl BertForMaskedLM {
+    /// Loads the base model with `vb.pp("bert")` and the head with
+    /// `vb.pp("cls")`, in that order.
+    ///
+    /// Returns an error if either component fails to load.
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        Ok(Self {
+            bert: BertModel::load(vb.pp("bert"), config)?,
+            cls: BertOnlyMLMHead::load(vb.pp("cls"), config)?,
+        })
+    }
+
+    /// Runs the base model on the given inputs and applies the masked-LM head
+    /// to its output.
+    ///
+    /// All three arguments are passed through to `BertModel::forward`
+    /// unchanged. An error from the base model is returned without invoking
+    /// the head.
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        self.bert
+            .forward(input_ids, token_type_ids, attention_mask)
+            .and_then(|sequence_output| self.cls.forward(&sequence_output))
+    }
+}