summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-10-21 10:19:23 +0100
committer: GitHub <noreply@github.com> 2023-10-21 10:19:23 +0100
commit: 94e3373883caaa7442201dac25abe16b4469f9bd (patch)
tree: b1fc4878f2e9c8d6fac74fa790d0e36b729c82df
parent: 34d9e9174824cc0656e083364fe68b85666843e0 (diff)
download: candle-94e3373883caaa7442201dac25abe16b4469f9bd.tar.gz
download: candle-94e3373883caaa7442201dac25abe16b4469f9bd.tar.bz2
download: candle-94e3373883caaa7442201dac25abe16b4469f9bd.zip
Blip forward pass (#1141)
* More forward methods for the blip model. * Blipping continues.
-rw-r--r--  candle-transformers/src/models/blip.rs  47
1 file changed, 42 insertions, 5 deletions
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index 4c2ca44d..dd1bcd48 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -104,10 +104,8 @@ impl Attention {
num_heads,
})
}
-}
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
let mixed_qkv = xs
.apply(&self.qkv)?
@@ -119,6 +117,10 @@ impl Module for Attention {
let attention_scores = query.matmul(&key.t()?)?;
let attention_scores = (attention_scores * self.scale)?;
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+ let attention_probs = match attn_mask {
+ None => attention_probs,
+ Some(attn_mask) => (attention_probs * attn_mask)?,
+ };
attention_probs
.matmul(&value)?
.permute((0, 2, 1, 3))?
@@ -178,10 +180,15 @@ impl EncoderLayer {
})
}
- fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
+ fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
- todo!()
+ let xs = self.self_attn.forward(&xs, attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
+ xs + residual
}
}
@@ -199,6 +206,14 @@ impl Encoder {
}
Ok(Self { layers })
}
+
+ fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, attention_mask)?
+ }
+ Ok(xs)
+ }
}
#[derive(Debug, Clone)]
@@ -222,6 +237,19 @@ impl VisionModel {
}
}
+impl Module for VisionModel {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.embeddings)?;
+ let encoder_outputs = self.encoder.forward(&xs, None)?;
+ let last_hidden_state = encoder_outputs.get(0)?;
+ last_hidden_state
+ .apply(&self.post_layernorm)?
+ .narrow(1, 0, 1)?
+ .squeeze(1)?
+ .apply(&self.post_layernorm)
+ }
+}
+
#[derive(Debug, Clone)]
struct BlipForConditionalGeneration {
vision_model: VisionModel,
@@ -238,4 +266,13 @@ impl BlipForConditionalGeneration {
text_decoder,
})
}
+
+ fn forward(
+ &self,
+ pixel_values: &Tensor,
+ input_ids: Option<&Tensor>,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ todo!()
+ }
}