diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-22 06:47:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-22 06:47:40 +0100 |
commit | 5b32c2a41e1b4897d6f8b70b1876a3984a31d94e (patch) | |
tree | ad54fa75ac8ec08c3a82d44cdfaac5054352d365 /candle-transformers | |
parent | 3115fe42e4b203b02219eaf85b749f6710d0de3e (diff) | |
download | candle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.tar.gz candle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.tar.bz2 candle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.zip |
Remove the unused pragma and properly apply the bias. (#1147)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/blip.rs | 7 | ||||
-rw-r--r-- | candle-transformers/src/models/blip_text.rs | 22 | ||||
-rw-r--r-- | candle-transformers/src/models/with_tracing.rs | 8 |
3 files changed, 15 insertions, 22 deletions
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs index 1b4f9008..daa96926 100644 --- a/candle-transformers/src/models/blip.rs +++ b/candle-transformers/src/models/blip.rs @@ -1,4 +1,3 @@ -#![allow(unused)] use super::blip_text; use super::with_tracing::{conv2d, linear, Conv2d, Linear}; use candle::{Module, Result, Tensor, D}; @@ -65,7 +64,6 @@ struct VisionEmbeddings { class_embedding: Tensor, patch_embedding: Conv2d, position_embedding: Tensor, - num_positions: usize, } impl VisionEmbeddings { @@ -91,7 +89,6 @@ impl VisionEmbeddings { class_embedding, patch_embedding, position_embedding, - num_positions, }) } } @@ -117,8 +114,6 @@ struct Attention { qkv: Linear, projection: Linear, scale: f64, - embed_dim: usize, - head_dim: usize, num_heads: usize, } @@ -134,8 +129,6 @@ impl Attention { qkv, projection, scale, - embed_dim, - head_dim, num_heads, }) } diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs index 6db2b9d8..f1a38f11 100644 --- a/candle-transformers/src/models/blip_text.rs +++ b/candle-transformers/src/models/blip_text.rs @@ -1,5 +1,4 @@ -#![allow(unused)] -use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; +use super::with_tracing::{linear, Embedding, Linear}; use candle::{Module, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; @@ -63,7 +62,6 @@ struct TextSelfAttention { query: Linear, key: Linear, value: Linear, - all_head_size: usize, attention_head_size: usize, num_attention_heads: usize, attention_scale: f64, @@ -87,7 +85,6 @@ impl TextSelfAttention { query, key, value, - all_head_size, attention_head_size, num_attention_heads, attention_scale, @@ -301,12 +298,12 @@ impl TextEncoder { } #[derive(Debug, Clone)] -struct TextPooler { +pub struct TextPooler { dense: Linear, } impl TextPooler { - fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; Ok(Self { dense }) } @@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform { struct TextLMPredictionHead { transform: TextPredictionHeadTransform, decoder: Linear, - bias: Tensor, } impl TextLMPredictionHead { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?; - let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?; + let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?; let bias = vb.get(cfg.vocab_size, "bias")?; - Ok(Self { - transform, - decoder, - bias, - }) + let decoder = Linear::from_weights(weight, Some(bias)); + Ok(Self { transform, decoder }) } } @@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead { struct TextModel { embeddings: TextEmbeddings, encoder: TextEncoder, - pooler: Option<TextPooler>, + // We do not need the pooler for caption generation } impl TextModel { @@ -406,7 +399,6 @@ impl TextModel { Ok(Self { embeddings, encoder, - pooler: None, }) } diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 69654139..39258085 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -32,6 +32,14 @@ pub struct Linear { span: tracing::Span, } +impl Linear { + pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self { + let inner = candle_nn::Linear::new(weights, bias); + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Self { inner, span } + } +} + pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> { let inner = candle_nn::linear(d1, d2, vb)?; let span = tracing::span!(tracing::Level::TRACE, "linear"); |