author    Justin Sing <32938975+singjc@users.noreply.github.com>    2024-12-04 15:22:30 -0500
committer GitHub <noreply@github.com>    2024-12-04 21:22:30 +0100
commit    1807be84f4d9e388b19710a9282eb6501ce55f80 (patch)
tree      c531dde68550403295cbc5c1296034f5cfcb2f6d
parent    145aa7193c4e658b184f52706574cc9f115e4674 (diff)
Change/bert encoder public (#2658)
* change: BertEncoder struct to public
* change: make certain fields in Config struct public
* change: all fields in bert config struct to be public
* change: add clone to bert encoder and others
* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r--  candle-transformers/src/models/bert.rs | 51
1 file changed, 30 insertions(+), 21 deletions(-)
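With the Config fields and the encoder's load/forward methods public, the encoder can be configured and driven without going through BertModel. The following is a minimal sketch of what this change enables, not code from the commit: crate names follow the crates.io packages (candle-core, candle-nn, candle-transformers), a zero-initialized VarBuilder stands in for real weights, and the attention-mask shape is an assumption based on the broadcast convention used inside the model.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config, HiddenAct};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Config fields are now public, so a configuration can be built or
    // tweaked directly instead of only via serde deserialization.
    let config = Config {
        hidden_size: 384,
        num_hidden_layers: 6,
        num_attention_heads: 12,
        intermediate_size: 1536,
        hidden_act: HiddenAct::Gelu,
        ..Default::default()
    };

    // BertEncoder::load and BertEncoder::forward are now public; a
    // zero-initialized VarBuilder stands in for real model weights here.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let encoder = BertEncoder::load(vb, &config)?;

    let (batch, seq_len) = (1, 8);
    let hidden_states =
        Tensor::zeros((batch, seq_len, config.hidden_size), DType::F32, &device)?;
    // Additive attention mask, broadcast over heads; all zeros means
    // "attend everywhere" (shape assumed from the model's broadcast_add).
    let attention_mask = Tensor::zeros((batch, 1, 1, seq_len), DType::F32, &device)?;

    let out = encoder.forward(&hidden_states, &attention_mask)?;
    println!("encoder output dims: {:?}", out.dims()); // [1, 8, 384]
    Ok(())
}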
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index da873416..0ff62c4f 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -22,6 +22,7 @@ pub enum HiddenAct {
Relu,
}
+#[derive(Clone)]
struct HiddenActLayer {
act: HiddenAct,
span: tracing::Span,
@@ -46,7 +47,7 @@ impl HiddenActLayer {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
#[default]
Absolute,
}
@@ -54,24 +55,24 @@ enum PositionEmbeddingType {
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
- vocab_size: usize,
- hidden_size: usize,
- num_hidden_layers: usize,
- num_attention_heads: usize,
- intermediate_size: usize,
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub intermediate_size: usize,
pub hidden_act: HiddenAct,
- hidden_dropout_prob: f64,
- max_position_embeddings: usize,
- type_vocab_size: usize,
- initializer_range: f64,
- layer_norm_eps: f64,
- pad_token_id: usize,
+ pub hidden_dropout_prob: f64,
+ pub max_position_embeddings: usize,
+ pub type_vocab_size: usize,
+ pub initializer_range: f64,
+ pub layer_norm_eps: f64,
+ pub pad_token_id: usize,
#[serde(default)]
- position_embedding_type: PositionEmbeddingType,
+ pub position_embedding_type: PositionEmbeddingType,
#[serde(default)]
- use_cache: bool,
- classifier_dropout: Option<f64>,
- model_type: Option<String>,
+ pub use_cache: bool,
+ pub classifier_dropout: Option<f64>,
+ pub model_type: Option<String>,
}
impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
}
}
+#[derive(Clone)]
struct Dropout {
#[allow(dead_code)]
pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
}
}
+#[derive(Clone)]
struct BertSelfAttention {
query: Linear,
key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
}
}
+#[derive(Clone)]
struct BertSelfOutput {
dense: Linear,
layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
struct BertAttention {
self_attention: BertSelfAttention,
self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
struct BertIntermediate {
dense: Linear,
intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
struct BertOutput {
dense: Linear,
layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
attention: BertAttention,
intermediate: BertIntermediate,
output: BertOutput,
@@ -420,13 +428,14 @@ impl BertLayer {
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
- layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+ pub layers: Vec<BertLayer>,
span: tracing::Span,
}
impl BertEncoder {
- fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
@@ -434,7 +443,7 @@ impl BertEncoder {
Ok(BertEncoder { layers, span })
}
- fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+ pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
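Because BertEncoder and BertLayer now derive Clone and the layers field is public, a loaded encoder can also be manipulated after the fact. A small sketch under the same assumptions as above; truncate_encoder is a hypothetical helper, not part of candle-transformers, and the cheapness of the clone rests on candle tensors being reference-counted so clones share the underlying weight buffers.

use candle_transformers::models::bert::BertEncoder;

// Hypothetical helper: keep only the first `n` layers of an encoder,
// e.g. for layer-truncation experiments. Relies on the now-public
// `layers` field and the new Clone derive.
fn truncate_encoder(encoder: &BertEncoder, n: usize) -> BertEncoder {
    let mut shallow = encoder.clone();
    shallow.layers.truncate(n);
    shallow
}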