summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r--candle-examples/examples/bert/main.rs12
1 files changed, 2 insertions, 10 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index d0d600ee..d8f6921e 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -86,7 +84,7 @@ impl Default for Config {
}
impl Config {
- fn all_mini_lm_l6_v2() -> Self {
+ fn _all_mini_lm_l6_v2() -> Self {
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
Self {
vocab_size: 30522,
@@ -121,6 +119,7 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
}
struct Dropout {
+ #[allow(dead_code)]
pr: f64,
}
@@ -156,8 +155,6 @@ struct BertEmbeddings {
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
dropout: Dropout,
- position_ids: Tensor,
- token_type_ids: Tensor,
}
impl BertEmbeddings {
@@ -182,17 +179,12 @@ impl BertEmbeddings {
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
- let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
- let position_ids = Tensor::new(&position_ids[..], vb.device())?.unsqueeze(0)?;
- let token_type_ids = position_ids.zeros_like()?;
Ok(Self {
word_embeddings,
position_embeddings: Some(position_embeddings),
token_type_embeddings,
layer_norm,
dropout: Dropout::new(config.hidden_dropout_prob),
- position_ids,
- token_type_ids,
})
}