summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-22 06:47:40 +0100
committerGitHub <noreply@github.com>2023-10-22 06:47:40 +0100
commit5b32c2a41e1b4897d6f8b70b1876a3984a31d94e (patch)
treead54fa75ac8ec08c3a82d44cdfaac5054352d365 /candle-transformers
parent3115fe42e4b203b02219eaf85b749f6710d0de3e (diff)
downloadcandle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.tar.gz
candle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.tar.bz2
candle-5b32c2a41e1b4897d6f8b70b1876a3984a31d94e.zip
Remove the unused pragma and properly apply the bias. (#1147)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/blip.rs7
-rw-r--r--candle-transformers/src/models/blip_text.rs22
-rw-r--r--candle-transformers/src/models/with_tracing.rs8
3 files changed, 15 insertions, 22 deletions
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index 1b4f9008..daa96926 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
use super::blip_text;
use super::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, Result, Tensor, D};
@@ -65,7 +64,6 @@ struct VisionEmbeddings {
class_embedding: Tensor,
patch_embedding: Conv2d,
position_embedding: Tensor,
- num_positions: usize,
}
impl VisionEmbeddings {
@@ -91,7 +89,6 @@ impl VisionEmbeddings {
class_embedding,
patch_embedding,
position_embedding,
- num_positions,
})
}
}
@@ -117,8 +114,6 @@ struct Attention {
qkv: Linear,
projection: Linear,
scale: f64,
- embed_dim: usize,
- head_dim: usize,
num_heads: usize,
}
@@ -134,8 +129,6 @@ impl Attention {
qkv,
projection,
scale,
- embed_dim,
- head_dim,
num_heads,
})
}
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
index 6db2b9d8..f1a38f11 100644
--- a/candle-transformers/src/models/blip_text.rs
+++ b/candle-transformers/src/models/blip_text.rs
@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use super::with_tracing::{linear, Embedding, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
@@ -63,7 +62,6 @@ struct TextSelfAttention {
query: Linear,
key: Linear,
value: Linear,
- all_head_size: usize,
attention_head_size: usize,
num_attention_heads: usize,
attention_scale: f64,
@@ -87,7 +85,6 @@ impl TextSelfAttention {
query,
key,
value,
- all_head_size,
attention_head_size,
num_attention_heads,
attention_scale,
@@ -301,12 +298,12 @@ impl TextEncoder {
}
#[derive(Debug, Clone)]
-struct TextPooler {
+pub struct TextPooler {
dense: Linear,
}
impl TextPooler {
- fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
Ok(Self { dense })
}
@@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
struct TextLMPredictionHead {
transform: TextPredictionHeadTransform,
decoder: Linear,
- bias: Tensor,
}
impl TextLMPredictionHead {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
- let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+ let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
let bias = vb.get(cfg.vocab_size, "bias")?;
- Ok(Self {
- transform,
- decoder,
- bias,
- })
+ let decoder = Linear::from_weights(weight, Some(bias));
+ Ok(Self { transform, decoder })
}
}
@@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
struct TextModel {
embeddings: TextEmbeddings,
encoder: TextEncoder,
- pooler: Option<TextPooler>,
+ // We do not need the pooler for caption generation
}
impl TextModel {
@@ -406,7 +399,6 @@ impl TextModel {
Ok(Self {
embeddings,
encoder,
- pooler: None,
})
}
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 69654139..39258085 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -32,6 +32,14 @@ pub struct Linear {
span: tracing::Span,
}
+impl Linear {
+ pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+ let inner = candle_nn::Linear::new(weights, bias);
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Self { inner, span }
+ }
+}
+
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");