summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/blip/main.rs54
-rw-r--r--candle-transformers/src/models/blip.rs108
-rw-r--r--candle-transformers/src/models/blip_text.rs106
3 files changed, 223 insertions, 45 deletions
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
new file mode 100644
index 00000000..82355778
--- /dev/null
+++ b/candle-examples/examples/blip/main.rs
@@ -0,0 +1,54 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::Parser;
+
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::blip;
+
+#[derive(Parser)]
+struct Args {
+ #[arg(long)]
+ model: Option<String>,
+
+ #[arg(long)]
+ image: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let image = candle_examples::imagenet::load_image224(args.image)?;
+ println!("loaded image {image:?}");
+
+ let model_file = match args.model {
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.repo(hf_hub::Repo::with_revision(
+ "Salesforce/blip-image-captioning-large".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/18".to_string(),
+ ));
+ api.get("model.safetensors")?
+ }
+ Some(model) => model.into(),
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+ let config = blip::Config::image_captioning_large();
+ let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+ println!("model built");
+ // TODO: Maybe add support for the conditional prompt.
+ let out = model.generate(&image.unsqueeze(0)?, None, None)?;
+ println!(">>>\n{out}");
+ Ok(())
+}
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index dd1bcd48..b2be112e 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -5,24 +5,59 @@ use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
-struct VisionConfig {
- hidden_size: usize,
- intermediate_size: usize,
- projection_dim: usize,
- num_hidden_layers: usize,
- num_attention_heads: usize,
- image_size: usize,
- patch_size: usize,
- hidden_act: candle_nn::Activation,
- layer_norm_eps: f64,
+pub struct VisionConfig {
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub projection_dim: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub image_size: usize,
+ pub patch_size: usize,
+ pub hidden_act: candle_nn::Activation,
+ pub layer_norm_eps: f64,
}
#[derive(Debug, Clone)]
-struct Config {
- text_config: blip_text::Config,
- vision_config: VisionConfig,
- projection_dim: usize,
- image_text_hidden_size: usize,
+pub struct Config {
+ pub text_config: blip_text::Config,
+ pub vision_config: VisionConfig,
+ pub projection_dim: usize,
+ pub image_text_hidden_size: usize,
+}
+
+impl Config {
+ pub fn image_captioning_large() -> Self {
+ let text_config = blip_text::Config {
+ vocab_size: 30524,
+ hidden_size: 768,
+ encoder_hidden_size: 1024,
+ intermediate_size: 3072,
+ projection_dim: 768,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ max_position_embeddings: 512,
+ hidden_act: candle_nn::Activation::Gelu,
+ layer_norm_eps: 1e-12,
+ is_decoder: true,
+ };
+ let vision_config = VisionConfig {
+ hidden_size: 1024,
+ intermediate_size: 4096,
+ projection_dim: 512,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ image_size: 384,
+ patch_size: 16,
+ hidden_act: candle_nn::Activation::Gelu,
+ layer_norm_eps: 1e-5,
+ };
+ Self {
+ text_config,
+ vision_config,
+ projection_dim: 512,
+ image_text_hidden_size: 256,
+ }
+ }
}
#[derive(Debug, Clone)]
@@ -200,6 +235,7 @@ struct Encoder {
impl Encoder {
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb = vb.pp("layers");
for i in 0..cfg.num_hidden_layers {
let layer = EncoderLayer::new(cfg, vb.pp(i))?;
layers.push(layer)
@@ -217,7 +253,7 @@ impl Encoder {
}
#[derive(Debug, Clone)]
-struct VisionModel {
+pub struct VisionModel {
embeddings: VisionEmbeddings,
encoder: Encoder,
post_layernorm: LayerNorm,
@@ -241,23 +277,19 @@ impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let encoder_outputs = self.encoder.forward(&xs, None)?;
- let last_hidden_state = encoder_outputs.get(0)?;
- last_hidden_state
- .apply(&self.post_layernorm)?
- .narrow(1, 0, 1)?
- .squeeze(1)?
- .apply(&self.post_layernorm)
+ // Return the last hidden state rather than pooled outputs.
+ encoder_outputs.apply(&self.post_layernorm)
}
}
#[derive(Debug, Clone)]
-struct BlipForConditionalGeneration {
+pub struct BlipForConditionalGeneration {
vision_model: VisionModel,
text_decoder: blip_text::TextLMHeadModel,
}
impl BlipForConditionalGeneration {
- fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
let text_decoder =
blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
@@ -267,12 +299,38 @@ impl BlipForConditionalGeneration {
})
}
- fn forward(
+ pub fn vision_model(&self) -> &VisionModel {
+ &self.vision_model
+ }
+
+ pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
+ &self.text_decoder
+ }
+
+ pub fn generate(
&self,
pixel_values: &Tensor,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
+ let image_embeds = pixel_values.apply(&self.vision_model)?;
+ let b_size = image_embeds.dim(0)?;
+ if b_size > 1 {
+ candle::bail!("only a batch size of 1 is supported")
+ }
+ let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
+ let mut token_ids = vec![30522u32];
+ for i in 0..1000 {
+ let input_ids =
+ Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
+ let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
+ println!("{logits:?}");
+ let logits = logits.squeeze(0)?;
+ let logits = logits.get(logits.dim(0)? - 1)?;
+ let token = logits_processor.sample(&logits)?;
+ println!("{token}");
+ token_ids.push(token)
+ }
todo!()
}
}
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
index 8b4fb4d1..8d0712c0 100644
--- a/candle-transformers/src/models/blip_text.rs
+++ b/candle-transformers/src/models/blip_text.rs
@@ -5,17 +5,17 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
pub struct Config {
- vocab_size: usize,
- hidden_size: usize,
- encoder_hidden_size: usize,
- intermediate_size: usize,
- projection_dim: usize,
- num_hidden_layers: usize,
- num_attention_heads: usize,
- max_position_embeddings: usize,
- hidden_act: candle_nn::Activation,
- layer_norm_eps: f64,
- is_decoder: bool,
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub encoder_hidden_size: usize,
+ pub intermediate_size: usize,
+ pub projection_dim: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub max_position_embeddings: usize,
+ pub hidden_act: candle_nn::Activation,
+ pub layer_norm_eps: f64,
+ pub is_decoder: bool,
}
#[derive(Debug, Clone)]
@@ -47,6 +47,17 @@ impl TextEmbeddings {
}
}
+impl Module for TextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let seq_len = xs.dim(1)?;
+ // Use past_key_values_length if we add a kv cache.
+ let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
+ let embeddings = self.word_embedddings.forward(xs)?;
+ let position_embeddings = self.position_embeddings.forward(&position_ids)?;
+ (embeddings + position_embeddings)?.apply(&self.layer_norm)
+ }
+}
+
#[derive(Debug, Clone)]
struct TextSelfAttention {
query: Linear,
@@ -55,6 +66,7 @@ struct TextSelfAttention {
all_head_size: usize,
attention_head_size: usize,
num_attention_heads: usize,
+ attention_scale: f64,
}
impl TextSelfAttention {
@@ -70,6 +82,7 @@ impl TextSelfAttention {
};
let key = linear(in_size, all_head_size, vb.pp("key"))?;
let value = linear(in_size, all_head_size, vb.pp("value"))?;
+ let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
Ok(Self {
query,
key,
@@ -77,6 +90,7 @@ impl TextSelfAttention {
all_head_size,
attention_head_size,
num_attention_heads,
+ attention_scale,
})
}
@@ -90,6 +104,35 @@ impl TextSelfAttention {
))?
.permute((0, 2, 1, 3))
}
+
+ fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
+ let query = self
+ .transpose_for_scores(&self.query.forward(xs)?)?
+ .contiguous()?;
+ let (key, value) = match encoder_hidden_states {
+ None => {
+ let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+ let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+ // TODO: kv cache
+ (key, value)
+ }
+ Some(xs) => {
+ let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+ let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+ // no kv-cache in this case, but the results could probably be memoized.
+ (key, value)
+ }
+ };
+ let key = key.contiguous()?;
+ let value = value.contiguous()?;
+ let attention_scores = query.matmul(&key.t()?)?;
+ let attention_scores = (attention_scores * self.attention_scale)?;
+ let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+ attention_probs
+ .matmul(&value)?
+ .permute((0, 2, 1, 3))?
+ .flatten_from(D::Minus2)
+ }
}
#[derive(Debug, Clone)]
@@ -122,6 +165,11 @@ impl TextAttention {
let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
Ok(Self { self_, output })
}
+
+ fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
+ let self_outputs = self.self_.forward(xs, encoder_hidden_states)?;
+ self.output.forward(&self_outputs, xs)
+ }
}
#[derive(Debug, Clone)]
@@ -176,7 +224,7 @@ impl TextLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
let cross_attention = if cfg.is_decoder {
- Some(TextAttention::new(cfg, true, vb.pp("attention"))?)
+ Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
} else {
None
};
@@ -189,11 +237,15 @@ impl TextLayer {
output,
})
}
-}
-impl Module for TextLayer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- todo!()
+ fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ let attention_output = self.attention.forward(xs, None)?;
+ let attention_output = match &self.cross_attention {
+ Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?,
+ None => candle::bail!("expected some cross-attn"),
+ };
+ let intermediate_output = self.intermediate.forward(&attention_output)?;
+ self.output.forward(&intermediate_output, &attention_output)
}
}
@@ -212,13 +264,11 @@ impl TextEncoder {
}
Ok(Self { layers })
}
-}
-impl Module for TextEncoder {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
- xs = xs.apply(layer)?
+ xs = layer.forward(&xs, encoder_hidden_states)?
}
Ok(xs)
}
@@ -333,6 +383,15 @@ impl TextModel {
pooler: None,
})
}
+
+ fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ let embedding_output = self.embeddings.forward(input_ids)?;
+ let sequence_output = self
+ .encoder
+ .forward(&embedding_output, encoder_hidden_states)?;
+ // We're interested in the sequence-output rather than the pooled-output.
+ Ok(sequence_output)
+ }
}
#[derive(Debug, Clone)]
@@ -347,4 +406,11 @@ impl TextLMHeadModel {
let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
Ok(Self { bert, cls })
}
+
+ pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?;
+ let prediction_scores = self.cls.forward(&sequence_output)?;
+ // return_logits is false so we don't discard the last sequence element.
+ Ok(prediction_scores)
+ }
}