summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/blip/main.rs70
-rw-r--r--candle-transformers/src/models/mod.rs2
-rw-r--r--candle-transformers/src/models/quantized_blip.rs258
-rw-r--r--candle-transformers/src/models/quantized_blip_text.rs476
-rw-r--r--candle-transformers/src/quantized_nn.rs6
5 files changed, 795 insertions, 17 deletions
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index 81c01482..45300feb 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -11,9 +11,25 @@ use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
+use candle_transformers::models::quantized_blip;
use tokenizers::Tokenizer;
+enum Model {
+ M(blip::BlipForConditionalGeneration),
+ Q(quantized_blip::BlipForConditionalGeneration),
+}
+
+impl Model {
+ fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::M(m) => m.text_decoder().forward(xs, img_xs),
+ Self::Q(m) => m.text_decoder().forward(xs, img_xs),
+ }
+ }
+}
+
+// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -28,6 +44,10 @@ struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
+
+ /// Use the quantized version of the model.
+ #[arg(long)]
+ quantized: bool,
}
const SEP_TOKEN_ID: u32 = 102;
@@ -54,20 +74,20 @@ pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
- let device = candle_examples::device(args.cpu)?;
-
- let image = load_image(args.image)?.to_device(&device)?;
- println!("loaded image {image:?}");
-
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
- let api = api.repo(hf_hub::Repo::with_revision(
- "Salesforce/blip-image-captioning-large".to_string(),
- hf_hub::RepoType::Model,
- "refs/pr/18".to_string(),
- ));
- api.get("model.safetensors")?
+ if args.quantized {
+ let api = api.model("lmz/candle-blip".to_string());
+ api.get("blip-image-captioning-large-q4k.gguf")?
+ } else {
+ let api = api.repo(hf_hub::Repo::with_revision(
+ "Salesforce/blip-image-captioning-large".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/18".to_string(),
+ ));
+ api.get("model.safetensors")?
+ }
}
Some(model) => model.into(),
};
@@ -84,19 +104,35 @@ pub fn main() -> anyhow::Result<()> {
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let config = blip::Config::image_captioning_large();
- let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
- println!("model built");
- // TODO: Maybe add support for the conditional prompt.
- let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+
+ let (image_embeds, device, mut model) = if args.quantized {
+ let device = Device::Cpu;
+ let image = load_image(args.image)?.to_device(&device)?;
+ println!("loaded image {image:?}");
+
+ let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+ let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
+ let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+ (image_embeds, device, Model::Q(model))
+ } else {
+ let device = candle_examples::device(args.cpu)?;
+ let image = load_image(args.image)?.to_device(&device)?;
+ println!("loaded image {image:?}");
+
+ let vb =
+ unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+ let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+ let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+ (image_embeds, device, Model::M(model))
+ };
let mut token_ids = vec![30522u32];
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
- let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
+ let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 6836b9c0..ce576c54 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -10,6 +10,8 @@ pub mod llama;
pub mod mistral;
pub mod mixformer;
pub mod mpt;
+pub mod quantized_blip;
+pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_mistral;
pub mod quantized_mixformer;
diff --git a/candle-transformers/src/models/quantized_blip.rs b/candle-transformers/src/models/quantized_blip.rs
new file mode 100644
index 00000000..6c498aa0
--- /dev/null
+++ b/candle-transformers/src/models/quantized_blip.rs
@@ -0,0 +1,258 @@
+use super::quantized_blip_text as blip_text;
+use crate::quantized_nn::{layer_norm, linear, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor, D};
+use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};
+
+pub type VisionConfig = super::blip::VisionConfig;
+pub type Config = super::blip::Config;
+
+#[derive(Debug, Clone)]
+struct VisionEmbeddings {
+ class_embedding: Tensor,
+ patch_embedding: Conv2d,
+ position_embedding: Tensor,
+}
+
+impl VisionEmbeddings {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let class_embedding = vb
+ .get((1, 1, cfg.hidden_size), "class_embedding")?
+ .dequantize(vb.device())?;
+ let conv_cfg = Conv2dConfig {
+ stride: cfg.patch_size,
+ ..Default::default()
+ };
+ let pe_vb = vb.pp("patch_embedding");
+ let pe_weight = pe_vb
+ .get(
+ (cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
+ "weight",
+ )?
+ .dequantize(vb.device())?;
+ let pe_bias = pe_vb
+ .get(cfg.hidden_size, "bias")?
+ .dequantize(vb.device())?;
+
+ let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
+ let num_patches1 = cfg.image_size / cfg.patch_size;
+ let num_patches = num_patches1 * num_patches1;
+ let num_positions = num_patches + 1;
+ let position_embedding = vb
+ .get((1, num_positions, cfg.hidden_size), "position_embedding")?
+ .dequantize(vb.device())?;
+ Ok(Self {
+ class_embedding,
+ patch_embedding,
+ position_embedding,
+ })
+ }
+}
+
+impl Module for VisionEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let target_dtype = xs.dtype();
+ let b_size = xs.dim(0)?;
+ let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
+ let d = self.class_embedding.dim(D::Minus1)?;
+ let class_embeds = self
+ .class_embedding
+ .broadcast_as((b_size, 1, d))?
+ .to_dtype(target_dtype)?;
+ let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
+ let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
+ embeddings.broadcast_add(&position_embedding)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ qkv: Linear,
+ projection: Linear,
+ scale: f64,
+ num_heads: usize,
+}
+
+impl Attention {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let embed_dim = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let head_dim = embed_dim / num_heads;
+ let scale = 1f64 / (head_dim as f64).sqrt();
+ let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
+ let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
+ Ok(Self {
+ qkv,
+ projection,
+ scale,
+ num_heads,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
+ let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
+ let mixed_qkv = xs
+ .apply(&self.qkv)?
+ .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
+ .permute((2, 0, 3, 1, 4))?;
+ let query = mixed_qkv.get(0)?;
+ let key = mixed_qkv.get(1)?;
+ let value = mixed_qkv.get(2)?;
+ let attention_scores = query.matmul(&key.t()?)?;
+ let attention_scores = (attention_scores * self.scale)?;
+ let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+ let attention_probs = match attn_mask {
+ None => attention_probs,
+ Some(attn_mask) => (attention_probs * attn_mask)?,
+ };
+ attention_probs
+ .matmul(&value)?
+ .permute((0, 2, 1, 3))?
+ .flatten_from(D::Minus2)?
+ .apply(&self.projection)
+ }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ activation_fn: candle_nn::Activation,
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl MLP {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
+ let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
+ Ok(Self {
+ activation_fn: cfg.hidden_act,
+ fc1,
+ fc2,
+ })
+ }
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.fc1)?
+ .apply(&self.activation_fn)?
+ .apply(&self.fc2)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+ self_attn: Attention,
+ layer_norm1: LayerNorm,
+ mlp: MLP,
+ layer_norm2: LayerNorm,
+}
+
+impl EncoderLayer {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let embed_dim = cfg.hidden_size;
+ let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
+ let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
+ let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ Ok(Self {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+ let residual = xs;
+ let xs = xs.apply(&self.layer_norm1)?;
+ let xs = self.self_attn.forward(&xs, attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Encoder {
+ layers: Vec<EncoderLayer>,
+}
+
+impl Encoder {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb = vb.pp("layers");
+ for i in 0..cfg.num_hidden_layers {
+ let layer = EncoderLayer::new(cfg, vb.pp(i))?;
+ layers.push(layer)
+ }
+ Ok(Self { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, attention_mask)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct VisionModel {
+ embeddings: VisionEmbeddings,
+ encoder: Encoder,
+ post_layernorm: LayerNorm,
+}
+
+impl VisionModel {
+ fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
+ let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+ let post_layernorm =
+ layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
+ Ok(Self {
+ embeddings,
+ encoder,
+ post_layernorm,
+ })
+ }
+}
+
+impl Module for VisionModel {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.embeddings)?;
+ let encoder_outputs = self.encoder.forward(&xs, None)?;
+ // Return the last hidden state rather than pooled outputs.
+ encoder_outputs.apply(&self.post_layernorm)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct BlipForConditionalGeneration {
+ vision_model: VisionModel,
+ text_decoder: blip_text::TextLMHeadModel,
+}
+
+impl BlipForConditionalGeneration {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
+ let text_decoder =
+ blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
+ Ok(Self {
+ vision_model,
+ text_decoder,
+ })
+ }
+
+ pub fn vision_model(&self) -> &VisionModel {
+ &self.vision_model
+ }
+
+ pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
+ &mut self.text_decoder
+ }
+}
diff --git a/candle-transformers/src/models/quantized_blip_text.rs b/candle-transformers/src/models/quantized_blip_text.rs
new file mode 100644
index 00000000..652205d6
--- /dev/null
+++ b/candle-transformers/src/models/quantized_blip_text.rs
@@ -0,0 +1,476 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor, D};
+use candle_nn::LayerNorm;
+
+pub type Config = super::blip_text::Config;
+
+#[derive(Debug, Clone)]
+struct TextEmbeddings {
+ word_embedddings: Embedding,
+ position_embeddings: Embedding,
+ layer_norm: LayerNorm,
+ position_ids: Tensor,
+}
+
+impl TextEmbeddings {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let word_embedddings =
+ Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
+ let position_embeddings = Embedding::new(
+ cfg.max_position_embeddings,
+ cfg.hidden_size,
+ vb.pp("position_embeddings"),
+ )?;
+ let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ let position_ids =
+ Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
+ Ok(Self {
+ word_embedddings,
+ position_embeddings,
+ layer_norm,
+ position_ids,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ let seq_len = xs.dim(1)?;
+ let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
+ let embeddings = self.word_embedddings.forward(xs)?;
+ let position_embeddings = self.position_embeddings.forward(&position_ids)?;
+ (embeddings + position_embeddings)?.apply(&self.layer_norm)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ attention_head_size: usize,
+ num_attention_heads: usize,
+ attention_scale: f64,
+ kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl TextSelfAttention {
+ fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+ let num_attention_heads = cfg.num_attention_heads;
+ let attention_head_size = cfg.hidden_size / num_attention_heads;
+ let all_head_size = cfg.num_attention_heads * attention_head_size;
+ let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
+ let in_size = if is_cross_attention {
+ cfg.encoder_hidden_size
+ } else {
+ cfg.hidden_size
+ };
+ let key = linear(in_size, all_head_size, vb.pp("key"))?;
+ let value = linear(in_size, all_head_size, vb.pp("value"))?;
+ let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
+ Ok(Self {
+ query,
+ key,
+ value,
+ attention_head_size,
+ num_attention_heads,
+ attention_scale,
+ kv_cache: None,
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, _) = xs.dims3()?;
+ xs.reshape((
+ b_size,
+ seq_len,
+ self.num_attention_heads,
+ self.attention_head_size,
+ ))?
+ .permute((0, 2, 1, 3))
+ }
+
+ fn reset_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let query = self
+ .transpose_for_scores(&self.query.forward(xs)?)?
+ .contiguous()?;
+ let (key, value) = match encoder_hidden_states {
+ None => {
+ let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+ let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+ let (key, value) = match &self.kv_cache {
+ None => (key, value),
+ Some((prev_key, prev_value)) => {
+ let key = Tensor::cat(&[prev_key, &key], 2)?;
+ let value = Tensor::cat(&[prev_value, &value], 2)?;
+ (key, value)
+ }
+ };
+ self.kv_cache = Some((key.clone(), value.clone()));
+ (key, value)
+ }
+ Some(xs) => {
+ let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+ let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+ // no kv-cache in this case, but the results could probably be memoized.
+ (key, value)
+ }
+ };
+ let key = key.contiguous()?;
+ let value = value.contiguous()?;
+ let attention_scores = query.matmul(&key.t()?)?;
+ let attention_scores = (attention_scores * self.attention_scale)?;
+ let attention_scores = match attention_mask {
+ Some(mask) => attention_scores.broadcast_add(mask)?,
+ None => attention_scores,
+ };
+ let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+ attention_probs
+ .matmul(&value)?
+ .permute((0, 2, 1, 3))?
+ .flatten_from(D::Minus2)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+}
+
+impl TextSelfOutput {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ Ok(Self { dense, layer_norm })
+ }
+
+ fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextAttention {
+ self_: TextSelfAttention,
+ output: TextSelfOutput,
+}
+
+impl TextAttention {
+ fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+ let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
+ let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
+ Ok(Self { self_, output })
+ }
+
+ fn reset_kv_cache(&mut self) {
+ self.self_.reset_kv_cache()
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let self_outputs = self
+ .self_
+ .forward(xs, encoder_hidden_states, attention_mask)?;
+ self.output.forward(&self_outputs, xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextIntermediate {
+ dense: Linear,
+ intermediate_act_fn: candle_nn::Activation,
+}
+
+impl TextIntermediate {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
+ Ok(Self {
+ dense,
+ intermediate_act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for TextIntermediate {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+}
+
+impl TextOutput {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ Ok(Self { dense, layer_norm })
+ }
+
+ fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextLayer {
+ attention: TextAttention,
+ cross_attention: Option<TextAttention>,
+ intermediate: TextIntermediate,
+ output: TextOutput,
+}
+
+impl TextLayer {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
+ let cross_attention = if cfg.is_decoder {
+ Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
+ } else {
+ None
+ };
+ let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
+ let output = TextOutput::new(cfg, vb.pp("output"))?;
+ Ok(Self {
+ attention,
+ cross_attention,
+ intermediate,
+ output,
+ })
+ }
+
+ fn reset_kv_cache(&mut self) {
+ self.attention.reset_kv_cache();
+ if let Some(ca) = &mut self.cross_attention {
+ ca.reset_kv_cache()
+ }
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
+ let attention_output = match &mut self.cross_attention {
+ Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
+ None => candle::bail!("expected some cross-attn"),
+ };
+ let intermediate_output = self.intermediate.forward(&attention_output)?;
+ self.output.forward(&intermediate_output, &attention_output)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextEncoder {
+ layers: Vec<TextLayer>,
+}
+
+impl TextEncoder {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb = vb.pp("layer");
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ for i in 0..cfg.num_hidden_layers {
+ let layer = TextLayer::new(cfg, vb.pp(i))?;
+ layers.push(layer)
+ }
+ Ok(Self { layers })
+ }
+
+ fn reset_kv_cache(&mut self) {
+ self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextPooler {
+ dense: Linear,
+}
+
+impl TextPooler {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ Ok(Self { dense })
+ }
+}
+
+impl Module for TextPooler {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.narrow(D::Minus1, 0, 1)?
+ .squeeze(D::Minus1)?
+ .apply(&self.dense)?
+ .tanh()
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextPredictionHeadTransform {
+ dense: Linear,
+ transform_act_fn: candle_nn::Activation,
+ layer_norm: LayerNorm,
+}
+
+impl TextPredictionHeadTransform {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ Ok(Self {
+ dense,
+ transform_act_fn: cfg.hidden_act,
+ layer_norm,
+ })
+ }
+}
+
+impl Module for TextPredictionHeadTransform {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.dense)?
+ .apply(&self.transform_act_fn)?
+ .apply(&self.layer_norm)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextLMPredictionHead {
+ transform: TextPredictionHeadTransform,
+ decoder: Linear,
+}
+
+impl TextLMPredictionHead {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
+ let weight = QMatMul::new(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+ let bias = vb.get(cfg.vocab_size, "bias")?.dequantize(vb.device())?;
+ let decoder = Linear::from_weights(weight, Some(bias));
+ Ok(Self { transform, decoder })
+ }
+}
+
+impl Module for TextLMPredictionHead {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.transform)?.apply(&self.decoder)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextOnlyMLMHead {
+ predictions: TextLMPredictionHead,
+}
+
+impl TextOnlyMLMHead {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
+ Ok(Self { predictions })
+ }
+}
+
+impl Module for TextOnlyMLMHead {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.predictions.forward(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct TextModel {
+ embeddings: TextEmbeddings,
+ encoder: TextEncoder,
+ past_kv_len: usize,
+ // We do not need the pooler for caption generation
+}
+
+impl TextModel {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
+ let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
+ Ok(Self {
+ embeddings,
+ encoder,
+ past_kv_len: 0,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ let (_b_sz, seq_len) = input_ids.dims2()?;
+ let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
+ let sequence_output =
+ self.encoder
+ .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
+ self.past_kv_len += seq_len;
+ // We're interested in the sequence-output rather than the pooled-output.
+ Ok(sequence_output)
+ }
+
+ fn reset_kv_cache(&mut self) {
+ self.past_kv_len = 0;
+ self.encoder.reset_kv_cache();
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextLMHeadModel {
+ bert: TextModel,
+ cls: TextOnlyMLMHead,
+}
+
+impl TextLMHeadModel {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let bert = TextModel::new(cfg, vb.pp("bert"))?;
+ let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
+ Ok(Self { bert, cls })
+ }
+
+ pub fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ let seq_len = input_ids.dim(1)?;
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
+ let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
+ let prediction_scores = self.cls.forward(&sequence_output)?;
+ // return_logits is false so we don't discard the last sequence element.
+ Ok(prediction_scores)
+ }
+
+ pub fn reset_kv_cache(&mut self) {
+ self.bert.reset_kv_cache()
+ }
+}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 2941c3f0..99e8d45b 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -34,6 +34,12 @@ pub struct Linear {
bias: Option<Tensor>,
}
+impl Linear {
+ pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
+ Self { weight, bias }
+ }
+}
+
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let x = x.apply(&self.weight)?;