Diffstat (limited to 'candle-transformers/src'):
 candle-transformers/src/models/mixformer.rs |  20 ++
 candle-transformers/src/models/mod.rs       |   1 +
 candle-transformers/src/models/moondream.rs | 308 ++++++++++++++++++++++++++++
 3 files changed, 329 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index f7eb0abe..edca8b9d 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -438,6 +438,26 @@ impl MixFormerSequentialForCausalLM {
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
 
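+    /// Forward pass where the image embeddings are prepended to the embedded
+    /// text tokens before running the transformer blocks.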
+    pub fn forward_with_img(&mut self, xs: &Tensor, img_embeds: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let xs = xs.apply(&self.embedding)?;
+        let mut xs = Tensor::cat(&[img_embeds.clone(), xs], 1)?;
+        let (_b_size, seq_len, _embds) = xs.dims3()?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
+        for block in self.blocks.iter_mut() {
+            xs = block.forward(&xs, mask.as_ref())?
+        }
+        let xs = xs
+            .narrow(1, seq_len - 1, 1)?
+            .apply(&self.head)?
+            .squeeze(1)?;
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
     }
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 980ba535..ed0e0de7 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -24,6 +24,7 @@ pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
 pub mod mobileone;
+pub mod moondream;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;
diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs
new file mode 100644
index 00000000..1172bf71
--- /dev/null
+++ b/candle-transformers/src/models/moondream.rs
@@ -0,0 +1,308 @@
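+//! Moondream: a compact vision-language model. A ViT-style image encoder
+//! (SigLIP-like dimensions: 1152-dim embeddings, 27 blocks, 14x14 patches)
+//! produces patch embeddings that are projected into the Phi (mixformer)
+//! embedding space and prepended to the text prompt.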
+use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
+
+pub struct Config {
+    pub phi_config: PhiConfig,
+    pub vision_config: VisionConfig,
+}
+
+impl Config {
+    pub fn v2() -> Self {
+        Self {
+            phi_config: PhiConfig::v1_5(),
+            vision_config: VisionConfig::v2(),
+        }
+    }
+}
+
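+// Plain scaled dot-product attention: softmax(q k^T / sqrt(head_dim)) v, with
+// no masking (the vision encoder attends over all patches).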
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    let dim = q.dim(D::Minus1)?;
+    let scale_factor = 1.0 / (dim as f64).sqrt();
+    let k = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
+    let attn_weights = (q.contiguous()?.matmul(&k)? * scale_factor)?;
+    let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?.contiguous()?;
+    attn_weights.matmul(&v.contiguous()?)
+}
+
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct VisionConfig {
+    image_embedding_dim: usize,
+    model_dim: usize,
+    hidden_dim: usize,
+    hidden_features: usize,
+    embed_len: usize,
+    embed_dim: usize,
+    num_blocks: usize,
+    num_heads: usize,
+    act: candle_nn::Activation,
+}
+
+impl VisionConfig {
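+    // The defaults below imply a square 378x378 input: with 14x14 patches,
+    // (378 / 14)^2 = 27^2 = 729 = embed_len.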
+    pub fn v2() -> Self {
+        Self {
+            image_embedding_dim: 1152,
+            model_dim: 2048,
+            hidden_dim: 2048 * 4,
+            hidden_features: 4304,
+            embed_len: 729,
+            embed_dim: 1152,
+            num_blocks: 27,
+            num_heads: 16,
+            act: candle_nn::Activation::Gelu,
+        }
+    }
+}
+
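+// Projects each flattened patch (3 channels x 14 x 14 = 588 values) into the
+// 1152-dim ViT embedding space.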
+#[derive(Debug, Clone)]
+struct LinearPatchEmbedding {
+    linear: Linear,
+}
+
+impl LinearPatchEmbedding {
+    fn new(vb: VarBuilder) -> Result<Self> {
+        let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
+        Ok(Self { linear })
+    }
+}
+
+impl Module for LinearPatchEmbedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear)
+    }
+}
+
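+// Multi-head self-attention with a fused qkv projection.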
+#[derive(Debug, Clone)]
+struct Attention {
+    num_heads: usize,
+    head_dim: usize,
+    qkv: Linear,
+    proj: Linear,
+}
+
+impl Attention {
+    pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
+        let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
+        let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
+        Ok(Self {
+            num_heads,
+            head_dim: dim / num_heads,
+            qkv,
+            proj,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, n, c) = xs.dims3()?;
+        // Split the fused projection into (q, k, v), each (b, heads, n, head_dim).
+        let qkv = xs
+            .apply(&self.qkv)?
+            .reshape((b, n, 3, self.num_heads, self.head_dim))?
+            .permute((2, 0, 3, 1, 4))?;
+        let (q, k, v) = (qkv.i(0)?, qkv.i(1)?, qkv.i(2)?);
+        let attn_out = scaled_dot_product_attention(&q, &k, &v)?;
+        // Merge the heads back together and apply the output projection.
+        let attn_out = attn_out.transpose(1, 2)?.reshape((b, n, c))?;
+        attn_out.apply(&self.proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VitBlock {
+    attn: Attention,
+    mlp: Mlp,
+    norm1: candle_nn::LayerNorm,
+    norm2: candle_nn::LayerNorm,
+}
+
+impl VitBlock {
+    fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
+        let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
+        let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
+        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        Ok(Self {
+            attn,
+            mlp,
+            norm1,
+            norm2,
+        })
+    }
+}
+
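+// Pre-norm transformer block: x + attn(norm1(x)), then x + mlp(norm2(x)).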
+impl Module for VitBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
+        let xs = (xs + &ys)?;
+        let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
+        let xs = (&xs + &ys)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionTransformer {
+    patch_embed: LinearPatchEmbedding,
+    pos_embed: Tensor,
+    blocks: Vec<VitBlock>,
+    norm: candle_nn::LayerNorm,
+}
+
+impl VisionTransformer {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
+        // One learned positional embedding per patch position (embed_len rows).
+        let pos_embed = vb.get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?;
+        let blocks = (0..cfg.num_blocks)
+            .map(|i| {
+                VitBlock::new(
+                    vb.pp(format!("blocks.{i}")),
+                    cfg.embed_dim,
+                    cfg.num_heads,
+                    cfg,
+                )
+            })
+            .collect::<Result<_>>()?;
+        let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
+        Ok(Self {
+            patch_embed,
+            pos_embed,
+            blocks,
+            norm,
+        })
+    }
+}
+
+impl Module for VisionTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
+        for block in self.blocks.iter() {
+            xs = xs.apply(block)?;
+        }
+        xs.apply(&self.norm)
+    }
+}
+
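+// Thin wrapper rooting the ViT at the "model.visual" prefix used by the
+// checkpoint.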
+#[derive(Debug, Clone)]
+pub struct Encoder {
+    model: VisionTransformer,
+}
+
+impl Encoder {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
+        Ok(Self { model })
+    }
+}
+
+impl Module for Encoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.model)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+    fc1: Linear,
+    act: candle_nn::Activation,
+    fc2: Linear,
+}
+
+impl Mlp {
+    fn new(
+        vb: VarBuilder,
+        in_features: usize,
+        hidden_features: usize,
+        out_features: usize,
+        act: candle_nn::Activation,
+    ) -> Result<Self> {
+        let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
+        let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
+        Ok(Self { fc1, act, fc2 })
+    }
+}
+
+impl Module for Mlp {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+    }
+}
+
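+// Two-layer MLP projecting 1152-dim image features (image_embedding_dim) into
+// the 2048-dim Phi embedding space (model_dim).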
+#[derive(Debug, Clone)]
+struct VisionProjection {
+    mlp: Mlp,
+}
+
+impl VisionProjection {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mlp = Mlp::new(
+            vb.pp("mlp"),
+            cfg.image_embedding_dim,
+            cfg.hidden_dim,
+            cfg.model_dim,
+            cfg.act,
+        )?;
+        Ok(Self { mlp })
+    }
+}
+
+impl Module for VisionProjection {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.mlp)
+    }
+}
+
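+// Full vision tower: patchify -> ViT encoder -> projection into the text
+// embedding space.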
+#[derive(Debug, Clone)]
+pub struct VisionEncoder {
+    encoder: Encoder,
+    projection: VisionProjection,
+}
+
+impl VisionEncoder {
+    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
+        Ok(Self {
+            encoder,
+            projection,
+        })
+    }
+}
+
+impl Module for VisionEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, c, hp1, wp2) = xs.dims4()?;
+        let (p1, p2) = (14, 14);
+        let h = hp1 / p1;
+        let w = wp2 / p2;
+        // Rearrange the image into a sequence of flattened 14x14 patches:
+        // (b, c, h*p1, w*p2) -> (b, h*w, c*p1*p2). The second spatial dim must
+        // be w, not h, so that non-square inputs reshape correctly.
+        let xs = xs
+            .reshape((b, c, h, p1, w, p2))?
+            .permute((0, 2, 4, 1, 3, 5))?
+            .reshape((b, h * w, c * p1 * p2))?;
+        xs.apply(&self.encoder)?.apply(&self.projection)
+    }
+}
+
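+// The combined model: the Phi text decoder plus the vision encoder, loaded
+// from the "text_model" and "vision_encoder" checkpoint prefixes.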
+pub struct Model {
+    pub text_model: PhiModel,
+    pub vision_encoder: VisionEncoder,
+}
+
+impl Model {
+    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
+        let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
+        let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
+        Ok(Self {
+            text_model,
+            vision_encoder,
+        })
+    }
+
+    pub fn vision_encoder(&self) -> &VisionEncoder {
+        &self.vision_encoder
+    }
+
+    pub fn text_model(&mut self) -> &mut PhiModel {
+        &mut self.text_model
+    }
+}
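
Usage sketch (editorial, not part of the commit): a minimal end-to-end flow,
assuming the checkpoint tensors are already loaded into a VarBuilder and the
image has been resized and normalized to the (1, 3, 378, 378) layout the v2
config implies. Crate paths follow this repository's conventions (as in
candle-examples); `run`, `img`, and `token_ids` are illustrative names.

use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::moondream::{Config, Model};

fn run(vb: VarBuilder, img: &Tensor, token_ids: &Tensor) -> Result<Tensor> {
    // Build the combined vision + text model from the checkpoint.
    let mut model = Model::new(&Config::v2(), vb)?;
    // (1, 3, 378, 378) image -> (1, 729, 2048) projected patch embeddings.
    let img_embeds = img.apply(model.vision_encoder())?;
    // Prepend the image embeddings to the embedded prompt tokens and return
    // the logits for the last position.
    model.text_model().forward_with_img(token_ids, &img_embeds)
}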