diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-21 20:05:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-21 20:05:02 +0100 |
commit | 0d9bb4eb18b01a06975e2d6eed2cec6735ec8400 (patch) | |
tree | b90472358623a83cfd611d39af1841285fbf5dff | |
parent | e8f760ee44ad4b1f9f3606e36a1966df8509203b (diff) | |
download | candle-0d9bb4eb18b01a06975e2d6eed2cec6735ec8400.tar.gz candle-0d9bb4eb18b01a06975e2d6eed2cec6735ec8400.tar.bz2 candle-0d9bb4eb18b01a06975e2d6eed2cec6735ec8400.zip |
Add the blip example. (#1144)
* Add the blip example.
* Tweak the example.
* Implement the cross-attn logic.
* Fix some shape mismatches.
* Get some logits out.
* Get some caption to be generated.
-rw-r--r-- | candle-examples/examples/blip/main.rs | 54 | ||||
-rw-r--r-- | candle-transformers/src/models/blip.rs | 108 | ||||
-rw-r--r-- | candle-transformers/src/models/blip_text.rs | 106 |
3 files changed, 223 insertions, 45 deletions
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs new file mode 100644 index 00000000..82355778 --- /dev/null +++ b/candle-examples/examples/blip/main.rs @@ -0,0 +1,54 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::DType; +use candle_nn::VarBuilder; +use candle_transformers::models::blip; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image224(args.image)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.repo(hf_hub::Repo::with_revision( + "Salesforce/blip-image-captioning-large".to_string(), + hf_hub::RepoType::Model, + "refs/pr/18".to_string(), + )); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let config = blip::Config::image_captioning_large(); + let model = blip::BlipForConditionalGeneration::new(&config, vb)?; + println!("model built"); + // TODO: Maybe add support for the conditional prompt. + let out = model.generate(&image.unsqueeze(0)?, None, None)?; + println!(">>>\n{out}"); + Ok(()) +} diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs index dd1bcd48..b2be112e 100644 --- a/candle-transformers/src/models/blip.rs +++ b/candle-transformers/src/models/blip.rs @@ -5,24 +5,59 @@ use candle::{Module, Result, Tensor, D}; use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder}; #[derive(Debug, Clone)] -struct VisionConfig { - hidden_size: usize, - intermediate_size: usize, - projection_dim: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - image_size: usize, - patch_size: usize, - hidden_act: candle_nn::Activation, - layer_norm_eps: f64, +pub struct VisionConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub projection_dim: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub image_size: usize, + pub patch_size: usize, + pub hidden_act: candle_nn::Activation, + pub layer_norm_eps: f64, } #[derive(Debug, Clone)] -struct Config { - text_config: blip_text::Config, - vision_config: VisionConfig, - projection_dim: usize, - image_text_hidden_size: usize, +pub struct Config { + pub text_config: blip_text::Config, + pub vision_config: VisionConfig, + pub projection_dim: usize, + pub image_text_hidden_size: usize, +} + +impl Config { + pub fn image_captioning_large() -> Self { + let text_config = blip_text::Config { + vocab_size: 30524, + hidden_size: 768, + encoder_hidden_size: 1024, + intermediate_size: 3072, + projection_dim: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + max_position_embeddings: 512, + hidden_act: candle_nn::Activation::Gelu, + layer_norm_eps: 1e-12, + is_decoder: true, + }; + let vision_config = VisionConfig { + hidden_size: 1024, + intermediate_size: 4096, + projection_dim: 512, + num_hidden_layers: 24, + num_attention_heads: 16, + image_size: 384, + patch_size: 16, + hidden_act: candle_nn::Activation::Gelu, + layer_norm_eps: 1e-5, + }; + Self { + text_config, + vision_config, + projection_dim: 512, + image_text_hidden_size: 256, + } + } } #[derive(Debug, Clone)] @@ -200,6 +235,7 @@ struct Encoder { impl Encoder { fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb = vb.pp("layers"); for i in 0..cfg.num_hidden_layers { let layer = EncoderLayer::new(cfg, vb.pp(i))?; layers.push(layer) @@ -217,7 +253,7 @@ impl Encoder { } #[derive(Debug, Clone)] -struct VisionModel { +pub struct VisionModel { embeddings: VisionEmbeddings, encoder: Encoder, post_layernorm: LayerNorm, @@ -241,23 +277,19 @@ impl Module for VisionModel { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let xs = xs.apply(&self.embeddings)?; let encoder_outputs = self.encoder.forward(&xs, None)?; - let last_hidden_state = encoder_outputs.get(0)?; - last_hidden_state - .apply(&self.post_layernorm)? - .narrow(1, 0, 1)? - .squeeze(1)? - .apply(&self.post_layernorm) + // Return the last hidden state rather than pooled outputs. + encoder_outputs.apply(&self.post_layernorm) } } #[derive(Debug, Clone)] -struct BlipForConditionalGeneration { +pub struct BlipForConditionalGeneration { vision_model: VisionModel, text_decoder: blip_text::TextLMHeadModel, } impl BlipForConditionalGeneration { - fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?; let text_decoder = blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?; @@ -267,12 +299,38 @@ impl BlipForConditionalGeneration { }) } - fn forward( + pub fn vision_model(&self) -> &VisionModel { + &self.vision_model + } + + pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel { + &self.text_decoder + } + + pub fn generate( &self, pixel_values: &Tensor, input_ids: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result<Tensor> { + let image_embeds = pixel_values.apply(&self.vision_model)?; + let b_size = image_embeds.dim(0)?; + if b_size > 1 { + candle::bail!("only a batch size of 1 is supported") + } + let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None); + let mut token_ids = vec![30522u32]; + for i in 0..1000 { + let input_ids = + Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?; + let logits = self.text_decoder.forward(&input_ids, &image_embeds)?; + println!("{logits:?}"); + let logits = logits.squeeze(0)?; + let logits = logits.get(logits.dim(0)? - 1)?; + let token = logits_processor.sample(&logits)?; + println!("{token}"); + token_ids.push(token) + } todo!() } } diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs index 8b4fb4d1..8d0712c0 100644 --- a/candle-transformers/src/models/blip_text.rs +++ b/candle-transformers/src/models/blip_text.rs @@ -5,17 +5,17 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder}; #[derive(Debug, Clone)] pub struct Config { - vocab_size: usize, - hidden_size: usize, - encoder_hidden_size: usize, - intermediate_size: usize, - projection_dim: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - max_position_embeddings: usize, - hidden_act: candle_nn::Activation, - layer_norm_eps: f64, - is_decoder: bool, + pub vocab_size: usize, + pub hidden_size: usize, + pub encoder_hidden_size: usize, + pub intermediate_size: usize, + pub projection_dim: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub max_position_embeddings: usize, + pub hidden_act: candle_nn::Activation, + pub layer_norm_eps: f64, + pub is_decoder: bool, } #[derive(Debug, Clone)] @@ -47,6 +47,17 @@ impl TextEmbeddings { } } +impl Module for TextEmbeddings { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let seq_len = xs.dim(1)?; + // Use past_key_values_length if we add a kv cache. + let position_ids = self.position_ids.narrow(1, 0, seq_len)?; + let embeddings = self.word_embedddings.forward(xs)?; + let position_embeddings = self.position_embeddings.forward(&position_ids)?; + (embeddings + position_embeddings)?.apply(&self.layer_norm) + } +} + #[derive(Debug, Clone)] struct TextSelfAttention { query: Linear, @@ -55,6 +66,7 @@ struct TextSelfAttention { all_head_size: usize, attention_head_size: usize, num_attention_heads: usize, + attention_scale: f64, } impl TextSelfAttention { @@ -70,6 +82,7 @@ impl TextSelfAttention { }; let key = linear(in_size, all_head_size, vb.pp("key"))?; let value = linear(in_size, all_head_size, vb.pp("value"))?; + let attention_scale = 1f64 / (attention_head_size as f64).sqrt(); Ok(Self { query, key, @@ -77,6 +90,7 @@ impl TextSelfAttention { all_head_size, attention_head_size, num_attention_heads, + attention_scale, }) } @@ -90,6 +104,35 @@ impl TextSelfAttention { ))? .permute((0, 2, 1, 3)) } + + fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> { + let query = self + .transpose_for_scores(&self.query.forward(xs)?)? + .contiguous()?; + let (key, value) = match encoder_hidden_states { + None => { + let key = self.transpose_for_scores(&self.key.forward(xs)?)?; + let value = self.transpose_for_scores(&self.value.forward(xs)?)?; + // TODO: kv cache + (key, value) + } + Some(xs) => { + let key = self.transpose_for_scores(&self.key.forward(xs)?)?; + let value = self.transpose_for_scores(&self.value.forward(xs)?)?; + // no kv-cache in this case, but the results could probably be memoized. + (key, value) + } + }; + let key = key.contiguous()?; + let value = value.contiguous()?; + let attention_scores = query.matmul(&key.t()?)?; + let attention_scores = (attention_scores * self.attention_scale)?; + let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + attention_probs + .matmul(&value)? + .permute((0, 2, 1, 3))? + .flatten_from(D::Minus2) + } } #[derive(Debug, Clone)] @@ -122,6 +165,11 @@ impl TextAttention { let output = TextSelfOutput::new(cfg, vb.pp("output"))?; Ok(Self { self_, output }) } + + fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> { + let self_outputs = self.self_.forward(xs, encoder_hidden_states)?; + self.output.forward(&self_outputs, xs) + } } #[derive(Debug, Clone)] @@ -176,7 +224,7 @@ impl TextLayer { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let attention = TextAttention::new(cfg, false, vb.pp("attention"))?; let cross_attention = if cfg.is_decoder { - Some(TextAttention::new(cfg, true, vb.pp("attention"))?) + Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?) } else { None }; @@ -189,11 +237,15 @@ impl TextLayer { output, }) } -} -impl Module for TextLayer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - todo!() + fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { + let attention_output = self.attention.forward(xs, None)?; + let attention_output = match &self.cross_attention { + Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?, + None => candle::bail!("expected some cross-attn"), + }; + let intermediate_output = self.intermediate.forward(&attention_output)?; + self.output.forward(&intermediate_output, &attention_output) } } @@ -212,13 +264,11 @@ impl TextEncoder { } Ok(Self { layers }) } -} -impl Module for TextEncoder { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { + fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { let mut xs = xs.clone(); for layer in self.layers.iter() { - xs = xs.apply(layer)? + xs = layer.forward(&xs, encoder_hidden_states)? } Ok(xs) } @@ -333,6 +383,15 @@ impl TextModel { pooler: None, }) } + + fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { + let embedding_output = self.embeddings.forward(input_ids)?; + let sequence_output = self + .encoder + .forward(&embedding_output, encoder_hidden_states)?; + // We're interested in the sequence-output rather than the pooled-output. + Ok(sequence_output) + } } #[derive(Debug, Clone)] @@ -347,4 +406,11 @@ impl TextLMHeadModel { let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?; Ok(Self { bert, cls }) } + + pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { + let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?; + let prediction_scores = self.cls.forward(&sequence_output)?; + // return_logits is false so we don't discard the last sequence element. + Ok(prediction_scores) + } } |