Diffstat (limited to 'candle-transformers/src')
 candle-transformers/src/lib.rs                                    |   1
 candle-transformers/src/models/dinov2.rs                          | 279
 candle-transformers/src/models/efficientnet.rs                    | 331
 candle-transformers/src/models/mod.rs                             |   4
 candle-transformers/src/models/quantized_llama.rs                 | 371
 candle-transformers/src/models/segment_anything/image_encoder.rs  | 483
 candle-transformers/src/models/segment_anything/mask_decoder.rs   | 239
 candle-transformers/src/models/segment_anything/mod.rs            | 100
 candle-transformers/src/models/segment_anything/prompt_encoder.rs | 239
 candle-transformers/src/models/segment_anything/sam.rs            | 411
 candle-transformers/src/models/segment_anything/tiny_vit.rs       | 633
 candle-transformers/src/models/segment_anything/transformer.rs    | 221
 candle-transformers/src/object_detection.rs                       |  52
13 files changed, 3364 insertions, 0 deletions
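For orientation before the diff itself: the new `dinov2` module below ends with a `vit_small` constructor, so a downstream caller could wire it up roughly as in this minimal sketch. The checkpoint path is a placeholder and the `VarBuilder::from_mmaped_safetensors` call is an assumption based on candle-nn's existing API, not part of this commit:

    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use candle_transformers::models::dinov2;

    fn load_dinov2() -> candle::Result<dinov2::DinoVisionTransformer> {
        let device = Device::Cpu;
        // Placeholder checkpoint path; any DINOv2 ViT-S/14 safetensors export should match.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["dinov2_vits14.safetensors"], DType::F32, &device)?
        };
        // vit_small instantiates 12 blocks with embed_dim 384 and 6 heads (see the diff below).
        dinov2::vit_small(vb)
    }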
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index a8890dc8..b83e5056 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,4 +1,5 @@ pub mod generation; pub mod models; +pub mod object_detection; pub mod pipelines; pub mod utils; diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs new file mode 100644 index 00000000..0edc8494 --- /dev/null +++ b/candle-transformers/src/models/dinov2.rs @@ -0,0 +1,279 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 1000; + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result<Self> { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? * self.scale)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result<Self> { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> 
Result<Tensor> { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result<Self> { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. + xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + pos_embed: Tensor, + blocks: Vec<Block>, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let num_tokens = 1; + let pos_embed = vb.get( + (1, patch_embed.num_patches + num_tokens, embed_dim), + "pos_embed", + )?; + let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + patch_embed, + cls_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(xs.clone()); + } + let class_pos_embed = self.pos_embed.i((.., ..1))?; + let patch_pos_embed = self.pos_embed.i((.., 1..))?; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + let patch_pos_embed = + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? 
+ .reshape((1, el_count / dim, dim))?; + Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _nc, w, h) = xs.dims4()?; + let xs = self.patch_embed.forward(xs)?; + let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; + &xs + &self.interpolate_pos_encoding(&xs, w, h)? + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; + let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; + self.head.forward(&xs) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { + DinoVisionTransformer::new(vb, 12, 384, 6) +} diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs new file mode 100644 index 00000000..ab51c76d --- /dev/null +++ b/candle-transformers/src/models/efficientnet.rs @@ -0,0 +1,331 @@ +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use nn::{Module, VarBuilder}; + +// Based on the Python version from torchvision. +// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 +#[derive(Debug, Clone, Copy)] +pub struct MBConvConfig { + expand_ratio: f64, + kernel: usize, + stride: usize, + input_channels: usize, + out_channels: usize, + num_layers: usize, +} + +fn make_divisible(v: f64, divisor: usize) -> usize { + let min_value = divisor; + let new_v = usize::max( + min_value, + (v + divisor as f64 * 0.5) as usize / divisor * divisor, + ); + if (new_v as f64) < 0.9 * v { + new_v + divisor + } else { + new_v + } +} + +fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { + let bneck_conf = |e, k, s, i, o, n| { + let input_channels = make_divisible(i as f64 * width_mult, 8); + let out_channels = make_divisible(o as f64 * width_mult, 8); + let num_layers = (n as f64 * depth_mult).ceil() as usize; + MBConvConfig { + expand_ratio: e, + kernel: k, + stride: s, + input_channels, + out_channels, + num_layers, + } + }; + vec![ + bneck_conf(1., 3, 1, 32, 16, 1), + bneck_conf(6., 3, 2, 16, 24, 2), + bneck_conf(6., 5, 2, 24, 40, 2), + bneck_conf(6., 3, 2, 40, 80, 3), + bneck_conf(6., 5, 1, 80, 112, 3), + bneck_conf(6., 5, 2, 112, 192, 4), + bneck_conf(6., 3, 1, 192, 320, 1), + ] +} + +impl MBConvConfig { + pub fn b0() -> Vec<Self> { + bneck_confs(1.0, 1.0) + } + pub fn b1() -> Vec<Self> { + bneck_confs(1.0, 1.1) + } + pub fn b2() -> Vec<Self> { + bneck_confs(1.1, 1.2) + } + pub fn b3() -> Vec<Self> { + bneck_confs(1.2, 1.4) + } + pub fn b4() -> Vec<Self> { + bneck_confs(1.4, 1.8) + } + pub fn b5() -> Vec<Self> { + bneck_confs(1.6, 2.2) + } + pub fn b6() -> Vec<Self> { + bneck_confs(1.8, 2.6) + } + pub fn b7() -> Vec<Self> { + bneck_confs(2.0, 3.1) + } +} + +/// Conv2D with same padding. +#[derive(Debug)] +struct Conv2DSame { + conv2d: nn::Conv2d, + s: usize, + k: usize, +} + +impl Conv2DSame { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + bias: bool, + ) -> Result<Self> { + let conv_config = nn::Conv2dConfig { + stride, + groups, + ..Default::default() + }; + let conv2d = if bias { + nn::conv2d(i, o, k, conv_config, vb)? + } else { + nn::conv2d_no_bias(i, o, k, conv_config, vb)? 
+ }; + Ok(Self { + conv2d, + s: stride, + k, + }) + } +} + +impl Module for Conv2DSame { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let s = self.s; + let k = self.k; + let (_, _, ih, iw) = xs.dims4()?; + let oh = (ih + s - 1) / s; + let ow = (iw + s - 1) / s; + let pad_h = usize::max((oh - 1) * s + k - ih, 0); + let pad_w = usize::max((ow - 1) * s + k - iw, 0); + if pad_h > 0 || pad_w > 0 { + let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; + let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; + self.conv2d.forward(&xs) + } else { + self.conv2d.forward(xs) + } + } +} + +#[derive(Debug)] +struct ConvNormActivation { + conv2d: Conv2DSame, + bn2d: nn::BatchNorm, + activation: bool, +} + +impl ConvNormActivation { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; + let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; + Ok(Self { + conv2d, + bn2d, + activation: true, + }) + } + + fn no_activation(self) -> Self { + Self { + activation: false, + ..self + } + } +} + +impl Module for ConvNormActivation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.conv2d.forward(xs)?; + let xs = self.bn2d.forward(&xs)?; + if self.activation { + swish(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +struct SqueezeExcitation { + fc1: Conv2DSame, + fc2: Conv2DSame, +} + +impl SqueezeExcitation { + fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { + let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; + let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for SqueezeExcitation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + // equivalent to adaptive_avg_pool2d([1, 1]) + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = self.fc1.forward(&xs)?; + let xs = swish(&xs)?; + let xs = self.fc2.forward(&xs)?; + let xs = nn::ops::sigmoid(&xs)?; + residual.broadcast_mul(&xs) + } +} + +#[derive(Debug)] +struct MBConv { + expand_cna: Option<ConvNormActivation>, + depthwise_cna: ConvNormActivation, + squeeze_excitation: SqueezeExcitation, + project_cna: ConvNormActivation, + config: MBConvConfig, +} + +impl MBConv { + fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { + let vb = vb.pp("block"); + let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); + let expand_cna = if exp != c.input_channels { + Some(ConvNormActivation::new( + vb.pp("0"), + c.input_channels, + exp, + 1, + 1, + 1, + )?) + } else { + None + }; + let start_index = if expand_cna.is_some() { 1 } else { 0 }; + let depthwise_cna = + ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; + let squeeze_channels = usize::max(1, c.input_channels / 4); + let squeeze_excitation = + SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; + let project_cna = + ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? 
+ .no_activation(); + Ok(Self { + expand_cna, + depthwise_cna, + squeeze_excitation, + project_cna, + config: c, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let use_res_connect = + self.config.stride == 1 && self.config.input_channels == self.config.out_channels; + let ys = match &self.expand_cna { + Some(expand_cna) => expand_cna.forward(xs)?, + None => xs.clone(), + }; + let ys = self.depthwise_cna.forward(&ys)?; + let ys = self.squeeze_excitation.forward(&ys)?; + let ys = self.project_cna.forward(&ys)?; + if use_res_connect { + ys + xs + } else { + Ok(ys) + } + } +} + +fn swish(s: &Tensor) -> Result<Tensor> { + s * nn::ops::sigmoid(s)? +} + +#[derive(Debug)] +pub struct EfficientNet { + init_cna: ConvNormActivation, + blocks: Vec<MBConv>, + final_cna: ConvNormActivation, + classifier: nn::Linear, +} + +impl EfficientNet { + pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { + let f_p = p.pp("features"); + let first_in_c = configs[0].input_channels; + let last_out_c = configs.last().unwrap().out_channels; + let final_out_c = 4 * last_out_c; + let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; + let nconfigs = configs.len(); + let mut blocks = vec![]; + for (index, cnf) in configs.into_iter().enumerate() { + let f_p = f_p.pp(index + 1); + for r_index in 0..cnf.num_layers { + let cnf = if r_index == 0 { + cnf + } else { + MBConvConfig { + input_channels: cnf.out_channels, + stride: 1, + ..cnf + } + }; + blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) + } + } + let final_cna = + ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; + let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; + Ok(Self { + init_cna, + blocks, + final_cna, + classifier, + }) + } +} + +impl Module for EfficientNet { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.init_cna.forward(xs)?; + for block in self.blocks.iter() { + xs = block.forward(&xs)? 
+ } + let xs = self.final_cna.forward(&xs)?; + // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) + let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; + self.classifier.forward(&xs) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 1b3dcf25..76e13b2a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1,5 +1,9 @@ pub mod bert; pub mod bigcode; +pub mod dinov2; +pub mod efficientnet; pub mod falcon; pub mod llama; +pub mod quantized_llama; +pub mod segment_anything; pub mod whisper; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs new file mode 100644 index 00000000..da0bd0b0 --- /dev/null +++ b/candle-transformers/src/models/quantized_llama.rs @@ -0,0 +1,371 @@ +use std::collections::HashMap; + +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +pub const MAX_SEQ_LEN: usize = 4096; + +struct RmsNorm { + inner: candle_nn::LayerNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(scale: QTensor, eps: f32) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let scale = scale.dequantize(&Device::Cpu)?; + let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +// QMatMul wrapper adding some tracing. +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Self { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Self { inner, span } + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +struct LayerWeights { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_wo: QMatMul, + attention_norm: RmsNorm, + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = self + .cos + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = self + .sin + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. 
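+ // Equivalently, each pair (x[2*j], x[2*j+1]) is rotated through the angle p * freq_base^(-2*j/head_dim), where p is the absolute position of the token.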
+ // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) + } + + fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; + let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // Support for MQA, useful for 70B models. + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = mask.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_kv_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } +} + +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + norm: RmsNorm, + output: QMatMul, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { + let cpu = &Device::Cpu; + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + ) -> Result<Self> { + let cpu = &Device::Cpu; + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; + let output = ct.tensor(reader, "output.weight")?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mask = self.mask(seq_len)?; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, &mask, index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let w1 = 
layer.feed_forward_w1.forward(&x)?; + let w3 = layer.feed_forward_w3.forward(&x)?; + let mlp = layer + .feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; + layer_in = (mlp + residual)?; + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +} diff --git a/candle-transformers/src/models/segment_anything/image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs new file mode 100644 index 00000000..0b313830 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/image_encoder.rs @@ -0,0 +1,483 @@ +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + span: tracing::Span, +} + +impl PatchEmbed { + fn new( + in_chans: usize, + embed_dim: usize, + k_size: usize, + stride: usize, + padding: usize, + vb: VarBuilder, + ) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + stride, + padding, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { proj, span }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.proj)?.permute((0, 2, 3, 1)) + } +} + +// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final +// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096 +// (attn.reshape((b, q_h, q_w, k_h, k_w))? +// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? +// .reshape((b, q_h * q_w, k_h * k_w)) +// Ideally we would perform this operation in place but this is not supported in candle at the +// moment. We should also investigate using f16 rather than f32. 
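+// Concretely, flattening the leading (b, q_h, q_w) dims into rows, the kernel below computes +// dst[row, h, w] = attn[row, h, w] + rel_h[row, h] + rel_w[row, w] for each row in parallel.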
+struct Add3(usize, usize, usize, usize, usize); +impl candle::CustomOp3 for Add3 { + fn name(&self) -> &'static str { + "add3" + } + + fn cpu_fwd( + &self, + s1: &candle::CpuStorage, + l1: &candle::Layout, + s2: &candle::CpuStorage, + l2: &candle::Layout, + s3: &candle::CpuStorage, + l3: &candle::Layout, + ) -> Result<(candle::CpuStorage, candle::Shape)> { + use rayon::prelude::*; + + let Add3(b, q_h, q_w, k_h, k_w) = *self; + let s1 = s1.as_slice::<f32>()?; + let s1 = match l1.contiguous_offsets() { + None => candle::bail!("input1 has to be contiguous"), + Some((o1, o2)) => &s1[o1..o2], + }; + let s2 = s2.as_slice::<f32>()?; + let s2 = match l2.contiguous_offsets() { + None => candle::bail!("input2 has to be contiguous"), + Some((o1, o2)) => &s2[o1..o2], + }; + let s3 = s3.as_slice::<f32>()?; + let s3 = match l3.contiguous_offsets() { + None => candle::bail!("input3 has to be contiguous"), + Some((o1, o2)) => &s3[o1..o2], + }; + let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w]; + dst.par_chunks_exact_mut(k_h * k_w) + .enumerate() + .for_each(|(b_idx, dst)| { + let s1_idx = b_idx * k_h * k_w; + let s2_idx = b_idx * k_h; + let s3_idx = b_idx * k_w; + for h_idx in 0..k_h { + let s1_idx = s1_idx + h_idx * k_w; + let s2_idx = s2_idx + h_idx; + let dst_idx = h_idx * k_w; + for w_idx in 0..k_w { + let s1_idx = s1_idx + w_idx; + let s3_idx = s3_idx + w_idx; + let dst_idx = dst_idx + w_idx; + dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx] + } + } + }); + let dst = candle::WithDType::to_cpu_storage_owned(dst); + Ok((dst, (b, q_h * q_w, k_h * k_w).into())) + } +} + +#[derive(Debug)] +struct Attention { + qkv: super::Linear, + proj: super::Linear, + num_heads: usize, + scale: f64, + rel_pos_hw: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_matmul: tracing::Span, + span_rel_pos: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = super::linear(vb.pp("proj"), dim, dim, true)?; + let head_dim = dim / num_heads; + let scale = 1. / (head_dim as f64).sqrt(); + let rel_pos_hw = if use_rel_pos { + let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?; + let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?; + Some((h, w)) + } else { + None + }; + Ok(Self { + qkv, + proj, + num_heads, + scale, + rel_pos_hw, + span, + span_matmul, + span_rel_pos, + span_softmax, + }) + } + + fn add_decomposed_rel_pos( + &self, + attn: Tensor, + q: &Tensor, + (q_h, q_w): (usize, usize), + (k_h, k_w): (usize, usize), + ) -> Result<Tensor> { + match &self.rel_pos_hw { + Some((rel_pos_h, rel_pos_w)) => { + let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; + let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; + let (b, _, dim) = q.dims3()?; + let r_q = q.reshape((b, q_h, q_w, dim))?; + // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; + // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + let rel_w = r_q + .transpose(1, 2)? // -> bwhc + .contiguous()? + .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? 
// bwhc,bwck -> bwhk + .transpose(1, 2)? + .contiguous()?; + if attn.device().is_cpu() { + let op = Add3(b, q_h, q_w, k_h, k_w); + attn.apply_op3_no_bwd(&rel_h, &rel_w, &op) + } else { + (attn.reshape((b, q_h, q_w, k_h, k_w))? + + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? + .reshape((b, q_h * q_w, k_h * k_w)) + } + } + None => Ok(attn), + } + } +} + +fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> { + let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; + let dev = rel_pos.device(); + let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { + todo!("interpolation") + } else { + rel_pos + }; + let q_coords = Tensor::arange(0u32, q_size as u32, dev)? + .reshape((q_size, 1))? + .to_dtype(DType::F32)?; + let k_coords = Tensor::arange(0u32, k_size as u32, dev)? + .reshape((1, k_size))? + .to_dtype(DType::F32)?; + let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; + let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; + let relative_coords = (q_coords.broadcast_sub(&k_coords)? + + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; + let (d1, d2) = relative_coords.dims2()?; + let relative_coords = relative_coords.to_dtype(DType::U32)?; + rel_pos_resized + .index_select(&relative_coords.reshape(d1 * d2)?, 0)? + .reshape((d1, d2, ())) +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, h, w, c) = xs.dims4()?; + let qkv = self + .qkv + .forward(&xs.flatten_to(1)?)? + .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? + .permute((2, 0, 3, 1, 4))? + .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; + let q = qkv.i(0)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = { + let _enter = self.span_matmul.enter(); + (&q * self.scale)?.matmul(&k.t()?)? + }; + let attn = { + let _enter = self.span_rel_pos.enter(); + self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))? + }; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + let attn = attn + .reshape((b, self.num_heads, h, w, c / self.num_heads))? + .permute((0, 2, 3, 1, 4))? 
+ .reshape((b, h * w, c))?; + self.proj.forward(&attn)?.reshape((b, h, w, c)) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + window_size: usize, + span: tracing::Span, +} + +impl Block { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + window_size: usize, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let input_size_attn = if window_size == 0 { + input_size + } else { + (window_size, window_size) + }; + let attn = Attention::new( + dim, + num_heads, + qkv_bias, + use_rel_pos, + input_size_attn, + vb.pp("attn"), + )?; + let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; + let span = tracing::span!(tracing::Level::TRACE, "ie-block"); + Ok(Self { + norm1, + attn, + norm2, + mlp, + window_size, + span, + }) + } +} + +fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { + let (b, h, w, c) = xs.dims4()?; + let pad_h = (window_size - h % window_size) % window_size; + let pad_w = (window_size - w % window_size) % window_size; + let xs = if pad_h > 0 { + xs.pad_with_zeros(1, 0, pad_h)? + } else { + xs + }; + let xs = if pad_w > 0 { + xs.pad_with_zeros(2, 0, pad_w)? + } else { + xs + }; + let (h_p, w_p) = (h + pad_h, w + pad_w); + let windows = xs + .reshape(( + b, + h_p / window_size, + window_size, + w_p / window_size, + window_size, + c, + ))? + .transpose(2, 3)? + .contiguous()? + .flatten_to(2)?; + Ok((windows, (h_p, w_p))) +} + +fn window_unpartition( + windows: Tensor, + window_size: usize, + (h_p, w_p): (usize, usize), + (h, w): (usize, usize), +) -> Result<Tensor> { + let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); + let xs = windows + .reshape(( + b, + h_p / window_size, + w_p / window_size, + window_size, + window_size, + windows.elem_count() / b / h_p / w_p, + ))? + .transpose(2, 3)? + .contiguous()? + .reshape((b, h_p, w_p, ()))?; + let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; + let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; + Ok(xs) +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let shortcut = xs; + let xs = self.norm1.forward(xs)?; + let hw = (xs.dim(1)?, xs.dim(2)?); + let (xs, pad_hw) = if self.window_size > 0 { + window_partition(xs, self.window_size)? + } else { + (xs, (0, 0)) + }; + let xs = self.attn.forward(&xs)?; + let xs = if self.window_size > 0 { + window_unpartition(xs, self.window_size, pad_hw, hw)? + } else { + xs + }; + let xs = (xs + shortcut)?; + &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? 
+ } +} + +#[derive(Debug)] +pub struct ImageEncoderViT { + patch_embed: PatchEmbed, + blocks: Vec<Block>, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + pos_embed: Option<Tensor>, + span: tracing::Span, +} + +impl ImageEncoderViT { + #[allow(clippy::too_many_arguments)] + pub fn new( + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + depth: usize, + num_heads: usize, + out_chans: usize, + qkv_bias: bool, + use_rel_pos: bool, + use_abs_pos: bool, + window_size: usize, + global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new( + in_chans, + embed_dim, + patch_size, + patch_size, + 0, + vb.pp("patch_embed"), + )?; + let mut blocks = Vec::with_capacity(depth); + let vb_b = vb.pp("blocks"); + for i in 0..depth { + let window_size = if global_attn_indexes.contains(&i) { + 0 + } else { + window_size + }; + let block = Block::new( + embed_dim, + num_heads, + qkv_bias, + use_rel_pos, + window_size, + (img_size / patch_size, img_size / patch_size), + vb_b.pp(i), + )?; + blocks.push(block) + } + let neck_conv1 = candle_nn::conv2d_no_bias( + embed_dim, + out_chans, + 1, + Default::default(), + vb.pp("neck.0"), + )?; + let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; + let pos_embed = if use_abs_pos { + let p = vb.get( + (1, img_size / patch_size, img_size / patch_size, embed_dim), + "pos_embed", + )?; + Some(p) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit"); + Ok(Self { + patch_embed, + blocks, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + pos_embed, + span, + }) + } +} + +impl Module for ImageEncoderViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = match &self.pos_embed { + Some(pos_embed) => (xs + pos_embed)?, + None => xs, + }; + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + xs.permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? 
+ .apply(&self.neck_ln2) + } +} diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs new file mode 100644 index 00000000..2a91cd44 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -0,0 +1,239 @@ +use candle::{IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +use super::transformer::TwoWayTransformer; + +#[derive(Debug)] +struct MlpMaskDecoder { + layers: Vec<super::Linear>, + sigmoid_output: bool, + span: tracing::Span, +} + +impl MlpMaskDecoder { + fn new( + input_dim: usize, + hidden_dim: usize, + output_dim: usize, + num_layers: usize, + sigmoid_output: bool, + vb: VarBuilder, + ) -> Result<Self> { + let mut layers = Vec::with_capacity(num_layers); + let vb = vb.pp("layers"); + for i in 0..num_layers { + let in_dim = if i == 0 { input_dim } else { hidden_dim }; + let out_dim = if i + 1 == num_layers { + output_dim + } else { + hidden_dim + }; + let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?; + layers.push(layer) + } + let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); + Ok(Self { + layers, + sigmoid_output, + span, + }) + } +} + +impl Module for MlpMaskDecoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for (i, layer) in self.layers.iter().enumerate() { + xs = layer.forward(&xs)?; + if i + 1 < self.layers.len() { + xs = xs.relu()? + } + } + if self.sigmoid_output { + candle_nn::ops::sigmoid(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +pub struct MaskDecoder { + iou_token: candle_nn::Embedding, + mask_tokens: candle_nn::Embedding, + iou_prediction_head: MlpMaskDecoder, + output_upscaling_conv1: candle_nn::ConvTranspose2d, + output_upscaling_ln: super::LayerNorm2d, + output_upscaling_conv2: candle_nn::ConvTranspose2d, + num_mask_tokens: usize, + output_hypernetworks_mlps: Vec<MlpMaskDecoder>, + transformer: TwoWayTransformer, + span: tracing::Span, +} + +impl MaskDecoder { + pub fn new( + transformer_dim: usize, + num_multimask_outputs: usize, + iou_head_depth: usize, + iou_head_hidden_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let num_mask_tokens = num_multimask_outputs + 1; + let iou_prediction_head = MlpMaskDecoder::new( + transformer_dim, + iou_head_hidden_dim, + num_mask_tokens, + iou_head_depth, + false, + vb.pp("iou_prediction_head"), + )?; + let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?; + let mask_tokens = + candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?; + let cfg = candle_nn::ConvTranspose2dConfig { + stride: 2, + ..Default::default() + }; + let output_upscaling_conv1 = candle_nn::conv_transpose2d( + transformer_dim, + transformer_dim / 4, + 2, + cfg, + vb.pp("output_upscaling.0"), + )?; + let output_upscaling_ln = + super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + let output_upscaling_conv2 = candle_nn::conv_transpose2d( + transformer_dim / 4, + transformer_dim / 8, + 2, + cfg, + vb.pp("output_upscaling.3"), + )?; + let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens); + let vb_o = vb.pp("output_hypernetworks_mlps"); + for i in 0..num_mask_tokens { + let mlp = MlpMaskDecoder::new( + transformer_dim, + transformer_dim, + transformer_dim / 8, + 3, + false, + vb_o.pp(i), + )?; + output_hypernetworks_mlps.push(mlp) + } + let transformer = TwoWayTransformer::new( + /* depth */ 2, + /* embedding_dim 
*/ transformer_dim, + /* num_heads */ 8, + /* mlp_dim */ 2048, + vb.pp("transformer"), + )?; + let span = tracing::span!(tracing::Level::TRACE, "mask-decoder"); + Ok(Self { + iou_token, + mask_tokens, + iou_prediction_head, + output_upscaling_conv1, + output_upscaling_ln, + output_upscaling_conv2, + num_mask_tokens, + output_hypernetworks_mlps, + transformer, + span, + }) + } + + pub fn forward( + &self, + image_embeddings: &Tensor, + image_pe: &Tensor, + sparse_prompt_embeddings: &Tensor, + dense_prompt_embeddings: &Tensor, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + let (masks, iou_pred) = self.predict_masks( + image_embeddings, + image_pe, + sparse_prompt_embeddings, + dense_prompt_embeddings, + )?; + let masks = if multimask_output { + masks.i((.., 1..))? + } else { + masks.i((.., 0..1))? + }; + let iou_pred = if multimask_output { + iou_pred.i((.., 1..))? + } else { + iou_pred.i((.., 0..1))? + }; + Ok((masks, iou_pred)) + } + + fn predict_masks( + &self, + image_embeddings: &Tensor, + image_pe: &Tensor, + sparse_prompt_embeddings: &Tensor, + dense_prompt_embeddings: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // Concatenate output tokens. + let output_tokens = Tensor::cat( + &[self.iou_token.embeddings(), self.mask_tokens.embeddings()], + 0, + )?; + let (d1, d2) = output_tokens.dims2()?; + let output_tokens = + output_tokens + .unsqueeze(0)? + .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?; + let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?; + + // Expand per-image data in batch direction to be per mask + let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?; + let src = src.broadcast_add(dense_prompt_embeddings)?; + let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?; + let (b, c, h, w) = src.dims4()?; + + // Run the transformer + let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; + let iou_token_out = hs.i((.., 0))?; + let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?; + + // Upscale mask embeddings and predict masks using the mask tokens. + let src = src.transpose(1, 2)?.reshape((b, c, h, w))?; + let upscaled_embedding = self + .output_upscaling_conv1 + .forward(&src)? + .apply(&self.output_upscaling_ln)? + .gelu()? + .apply(&self.output_upscaling_conv2)? + .gelu()?; + let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens); + for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() { + let h = mlp.forward(&mask_tokens_out.i((.., i))?)?; + hyper_in_list.push(h) + } + let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?; + let (b, c, h, w) = upscaled_embedding.dims4()?; + let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; + let masks = masks.reshape((b, (), h, w))?; + + // Generate mask quality predictions.
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; + Ok((masks, iou_pred)) + } +} + +// Equivalent to torch.repeat_interleave +fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> { + let img = img.unsqueeze(dim + 1)?; + let mut dims = img.dims().to_vec(); + dims[dim + 1] = repeats; + img.broadcast_as(dims)?.flatten(dim, dim + 1) +} diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs new file mode 100644 index 00000000..c29db70a --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -0,0 +1,100 @@ +use candle::{Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +pub mod image_encoder; +pub mod mask_decoder; +pub mod prompt_encoder; +pub mod sam; +pub mod tiny_vit; +pub mod transformer; + +pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +#[derive(Debug)] +pub struct LayerNorm2d { + weight: Tensor, + bias: Tensor, + num_channels: usize, + eps: f64, +} + +impl LayerNorm2d { + pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(num_channels, "weight")?; + let bias = vb.get(num_channels, "bias")?; + Ok(Self { + weight, + bias, + num_channels, + eps, + }) + } +} + +impl Module for LayerNorm2d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let u = xs.mean_keepdim(1)?; + let xs = xs.broadcast_sub(&u)?; + let s = xs.sqr()?.mean_keepdim(1)?; + let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; + xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? + .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) + } +} + +#[derive(Debug)] +pub struct MlpBlock { + lin1: Linear, + lin2: Linear, + activation: candle_nn::Activation, + span: tracing::Span, +} + +impl MlpBlock { + pub fn new( + embedding_dim: usize, + mlp_dim: usize, + activation: candle_nn::Activation, + vb: VarBuilder, + ) -> Result<Self> { + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); + Ok(Self { + lin1, + lin2, + activation, + span, + }) + } +} + +impl Module for MlpBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.lin1)? + .apply(&self.activation)? 
+ .apply(&self.lin2) + } +} + +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs new file mode 100644 index 00000000..9d0074b1 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -0,0 +1,239 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct PostionEmbeddingRandom { + positional_encoding_gaussian_matrix: Tensor, +} + +impl PostionEmbeddingRandom { + fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> { + let positional_encoding_gaussian_matrix = + vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?; + Ok(Self { + positional_encoding_gaussian_matrix, + }) + } + + fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> { + let coords = coords.affine(2., -1.)?; + let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?; + let coords = (coords * (2. * std::f64::consts::PI))?; + Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) + } + + fn forward(&self, h: usize, w: usize) -> Result<Tensor> { + let device = self.positional_encoding_gaussian_matrix.device(); + let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let x_embed = (x_embed / w as f64)? + .reshape((1, ()))? + .broadcast_as((h, w))?; + let y_embed = (y_embed / h as f64)? + .reshape(((), 1))? + .broadcast_as((h, w))?; + let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; + self.pe_encoding(&coords)?.permute((2, 0, 1)) + } + + fn forward_with_coords( + &self, + coords_input: &Tensor, + image_size: (usize, usize), + ) -> Result<Tensor> { + let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?; + let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? 
/ image_size.0 as f64)?; + let c = coords_input.dim(D::Minus1)?; + let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?; + let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?; + self.pe_encoding(&coords) + } +} + +#[derive(Debug)] +pub struct PromptEncoder { + pe_layer: PostionEmbeddingRandom, + point_embeddings: Vec<candle_nn::Embedding>, + not_a_point_embed: candle_nn::Embedding, + mask_downscaling_conv1: candle_nn::Conv2d, + mask_downscaling_ln1: super::LayerNorm2d, + mask_downscaling_conv2: candle_nn::Conv2d, + mask_downscaling_ln2: super::LayerNorm2d, + mask_downscaling_conv3: candle_nn::Conv2d, + no_mask_embed: candle_nn::Embedding, + image_embedding_size: (usize, usize), + input_image_size: (usize, usize), + embed_dim: usize, + span: tracing::Span, +} + +impl PromptEncoder { + pub fn new( + embed_dim: usize, + image_embedding_size: (usize, usize), + input_image_size: (usize, usize), + mask_in_chans: usize, + vb: VarBuilder, + ) -> Result<Self> { + let num_points_embeddings = 4; + let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; + let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?; + let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?; + let cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let mask_downscaling_conv1 = + candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?; + let mask_downscaling_conv2 = candle_nn::conv2d( + mask_in_chans / 4, + mask_in_chans, + 2, + cfg, + vb.pp("mask_downscaling.3"), + )?; + let mask_downscaling_conv3 = candle_nn::conv2d( + mask_in_chans, + embed_dim, + 1, + Default::default(), + vb.pp("mask_downscaling.6"), + )?; + let mask_downscaling_ln1 = + super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; + let mask_downscaling_ln2 = + super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; + let mut point_embeddings = Vec::with_capacity(num_points_embeddings); + let vb_e = vb.pp("point_embeddings"); + for i in 0..num_points_embeddings { + let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?; + point_embeddings.push(emb) + } + let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder"); + Ok(Self { + pe_layer, + point_embeddings, + not_a_point_embed, + mask_downscaling_conv1, + mask_downscaling_ln1, + mask_downscaling_conv2, + mask_downscaling_ln2, + mask_downscaling_conv3, + no_mask_embed, + image_embedding_size, + input_image_size, + embed_dim, + span, + }) + } + + pub fn get_dense_pe(&self) -> Result<Tensor> { + self.pe_layer + .forward(self.image_embedding_size.0, self.image_embedding_size.1)? + .unsqueeze(0) + } + + fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> { + masks + .apply(&self.mask_downscaling_conv1)? + .apply(&self.mask_downscaling_ln1)? + .gelu()? + .apply(&self.mask_downscaling_conv2)? + .apply(&self.mask_downscaling_ln2)? + .gelu()? + .apply(&self.mask_downscaling_conv3) + } + + fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> { + let points = (points + 0.5)?; + let dev = points.device(); + let (points, labels) = if pad { + let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?; + let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? 
+ let points = Tensor::cat(&[&points, &padding_point], 1)?;
+ let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+ (points, labels)
+ } else {
+ (points, labels.clone())
+ };
+ let point_embedding = self
+ .pe_layer
+ .forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
+ let zeros = point_embedding.zeros_like()?;
+ let point_embedding = labels.lt(0f32)?.where_cond(
+ &self
+ .not_a_point_embed
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &point_embedding,
+ )?;
+ let labels0 = labels.eq(0f32)?.where_cond(
+ &self.point_embeddings[0]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels0)?;
+ let labels1 = labels.eq(1f32)?.where_cond(
+ &self.point_embeddings[1]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels1)?;
+ Ok(point_embedding)
+ }
+
+ fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+ let boxes = (boxes + 0.5)?;
+ let coords = boxes.reshape(((), 2, 2))?;
+ let corner_embedding = self
+ .pe_layer
+ .forward_with_coords(&coords, self.input_image_size)?;
+ let ce1 = corner_embedding.i((.., 0))?;
+ let ce2 = corner_embedding.i((.., 1))?;
+ let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+ let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+ Tensor::cat(&[&ce1, &ce2], 1)
+ }
+
+ pub fn forward(
+ &self,
+ points: Option<(&Tensor, &Tensor)>,
+ boxes: Option<&Tensor>,
+ masks: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let se_points = match points {
+ Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+ None => None,
+ };
+ let se_boxes = match boxes {
+ Some(boxes) => Some(self.embed_boxes(boxes)?),
+ None => None,
+ };
+ let sparse_embeddings = match (se_points, se_boxes) {
+ (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+ (Some(se_points), None) => se_points,
+ (None, Some(se_boxes)) => se_boxes,
+ (None, None) => {
+ // Allocate the empty placeholder on the module's device rather
+ // than hardcoding the CPU, so GPU runs don't mix devices.
+ let dev = self.no_mask_embed.embeddings().device();
+ Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)?
+ }
+ };
+
+ let dense_embeddings = match masks {
+ None => {
+ let emb = self.no_mask_embed.embeddings();
+ emb.reshape((1, (), 1, 1))?.expand((
+ 1,
+ emb.elem_count(),
+ self.image_embedding_size.0,
+ self.image_embedding_size.1,
+ ))?
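+ // Without a mask prompt, the learned `no_mask_embed` vector is
+ // broadcast over the whole (1, embed_dim, 64, 64) grid of dense
+ // embeddings (64 = IMAGE_SIZE / VIT_PATCH_SIZE).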
+ }
+ Some(masks) => self.embed_masks(masks)?,
+ };
+ Ok((sparse_embeddings, dense_embeddings))
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
new file mode 100644
index 00000000..c40473e3
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -0,0 +1,411 @@
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+use super::image_encoder::ImageEncoderViT;
+use super::mask_decoder::MaskDecoder;
+use super::prompt_encoder::PromptEncoder;
+use super::tiny_vit::{tiny_vit_5m, TinyViT};
+
+const PROMPT_EMBED_DIM: usize = 256;
+pub const IMAGE_SIZE: usize = 1024;
+const VIT_PATCH_SIZE: usize = 16;
+const PRED_IOU_THRESH: f32 = 0.88;
+const STABILITY_SCORE_OFFSET: f32 = 1.0;
+const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
+const MODEL_MASK_THRESHOLD: f32 = 0.0;
+const CROP_NMS_THRESH: f32 = 0.7;
+
+#[derive(Debug)]
+enum ImageEncoder {
+ Original(ImageEncoderViT),
+ TinyViT(TinyViT),
+}
+
+impl Module for ImageEncoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::Original(vit) => vit.forward(xs),
+ Self::TinyViT(vit) => vit.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Sam {
+ image_encoder: ImageEncoder,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: Tensor,
+ pixel_std: Tensor,
+}
+
+impl Sam {
+ pub fn new(
+ encoder_embed_dim: usize,
+ encoder_depth: usize,
+ encoder_num_heads: usize,
+ encoder_global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = ImageEncoderViT::new(
+ IMAGE_SIZE,
+ VIT_PATCH_SIZE,
+ 3,
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ PROMPT_EMBED_DIM,
+ /* qkv_bias */ true,
+ /* use_rel_pos */ true,
+ /* use_abs_pos */ true,
+ /* window_size */ 14,
+ /* global_attn_indexes */ encoder_global_attn_indexes,
+ vb.pp("image_encoder"),
+ )?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::Original(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+
+ pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::TinyViT(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
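+ // A minimal usage sketch (the weight file name is illustrative; any
+ // MobileSAM checkpoint in safetensors format should work, and `img`
+ // is a (3, h, w) image tensor that `preprocess` converts to f32):
+ //
+ // let vb = unsafe {
+ // VarBuilder::from_mmaped_safetensors(
+ // &["mobile_sam-tiny-vitt.safetensors"], DType::F32, &device)?
+ // };
+ // let sam = Sam::new_tiny(vb)?;
+ // // (0.5, 0.5) is a point prompt at the image center, in relative coords.
+ // let (mask, iou) = sam.forward(&img, Some((0.5, 0.5)), false)?;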
+ } + + pub fn forward( + &self, + img: &Tensor, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let (_c, original_h, original_w) = img.dims3()?; + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = match point { + None => None, + Some((x, y)) => { + let points = Tensor::new( + &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], + img.device(), + )?; + let labels = Tensor::ones((1, 1), DType::F32, img.device())?; + Some((points, labels)) + } + }; + let points = points.as_ref().map(|(x, y)| (x, y)); + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(points, None, None)?; + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + )?; + let mask = low_res_mask + .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? + .get(0)? + .i((.., ..original_h, ..original_w))?; + Ok((mask, iou_predictions)) + } + + pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { + let img = img + .broadcast_mul(&self.pixel_std)? + .broadcast_add(&self.pixel_mean)?; + img.maximum(&img.zeros_like()?)? + .minimum(&(img.ones_like()? * 255.)?) + } + + pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + let (_c, h, w) = img.dims3()?; + let img = img + .to_dtype(DType::F32)? + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } + + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { + // Crop the image and calculate embeddings. + let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; + let img = self.preprocess(&img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + + let crop_w = cb.x1 - cb.x0; + let crop_h = cb.y1 - cb.y0; + + // Generate masks for this crop. + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = point_grids + .iter() + .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) + .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); + for points in points.chunks(64) { + // Run the model on this batch. + let points_len = points.len(); + let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; + let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder + .forward(Some((&in_points, &in_labels)), None, None)?; + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + /* multimask_output */ true, + )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? 
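+ // The stability score is the ratio between the mask areas obtained
+ // at logit thresholds (t + offset) and (t - offset): masks whose
+ // area barely changes when the threshold moves are considered stable.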
+ .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } + + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = crate::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } + } + + let mut bboxes = vec![bboxes]; + // Remove duplicates within this crop. + crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); + + // TODO: Return to the original image frame. + Ok(bboxes.remove(0)) + } + + pub fn generate_masks( + &self, + img: &Tensor, + points_per_side: usize, + crop_n_layer: usize, + crop_overlap_ratio: f64, + crop_n_points_downscale_factor: usize, + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { + let (_c, h, w) = img.dims3()?; + let point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layer, + crop_n_points_downscale_factor, + ); + let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + let mut bboxes = Vec::new(); + for crop_box in crop_boxes.into_iter() { + let layer_idx = crop_box.layer_idx; + let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; + bboxes.extend(b) + } + // TODO: remove duplicates + Ok(bboxes) + } +} + +// Return the first and last indexes i for which values[i] > 0 +fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { + let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); + for (i, &s) in values.iter().enumerate() { + if s == 0 { + continue; + } + min_i = usize::min(i, min_i); + max_i = usize::max(i, max_i); + } + if max_i < min_i { + None + } else { + Some((min_i, max_i)) + } +} + +#[derive(Debug)] +struct CropBox { + x0: usize, + y0: usize, + x1: usize, + y1: usize, + layer_idx: usize, +} + +impl CropBox { + fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self { + Self { + x0, + y0, + x1, + y1, + layer_idx, + } + } +} + +fn generate_crop_boxes( + (im_h, im_w): (usize, usize), + n_layers: usize, + overlap_ratio: f64, +) -> Vec<CropBox> { + fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize { + f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize + } + + let short_side = usize::min(im_h, im_w); + + let mut crop_boxes = Vec::new(); + + // Original image. + crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0)); + + for layer_idx in 1..=n_layers { + let n_crops_per_side = 1 << layer_idx; + let overlap = (overlap_ratio * short_side as f64 * 2. 
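+ // Layer l uses 2^l crops per side; the per-crop overlap is a fixed
+ // fraction of the image's short side, spread across the crops.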
/ n_crops_per_side as f64) as usize;
+ let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+ let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+ for i_x in 0..n_crops_per_side {
+ let x0 = (crop_w - overlap) * i_x;
+ for i_y in 0..n_crops_per_side {
+ let y0 = (crop_h - overlap) * i_y;
+ let x1 = usize::min(im_w, x0 + crop_w);
+ let y1 = usize::min(im_h, y0 + crop_h);
+ crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+ }
+ }
+ }
+
+ crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
+ let offset = 1f64 / (2 * n_per_side) as f64;
+ let mut points = Vec::with_capacity(n_per_side * n_per_side);
+ for i_x in 0..n_per_side {
+ let x = offset + i_x as f64 / n_per_side as f64;
+ for i_y in 0..n_per_side {
+ let y = offset + i_y as f64 / n_per_side as f64;
+ points.push((x, y))
+ }
+ }
+ points
+}
+
+fn build_all_layer_point_grids(
+ n_per_side: usize,
+ n_layers: usize,
+ scale_per_layer: usize,
+) -> Vec<Vec<(f64, f64)>> {
+ let mut points_by_layer = Vec::with_capacity(n_layers + 1);
+ for i in 0..=n_layers {
+ let n_points = n_per_side / scale_per_layer.pow(i as u32);
+ points_by_layer.push(build_point_grid(n_points))
+ }
+ points_by_layer
+}
diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
new file mode 100644
index 00000000..cd2936ab
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -0,0 +1,633 @@
+// Adapted from:
+// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{Conv2dConfig, Module, VarBuilder};
+
+const MBCONV_EXPAND_RATIO: usize = 4;
+const MLP_RATIO: usize = 4;
+const LOCAL_CONV_SIZE: usize = 3;
+const IMG_SIZE: usize = 1024;
+const IN_CHANNELS: usize = 3;
+
+#[derive(Debug)]
+struct Conv2dBN {
+ c: candle_nn::Conv2d,
+ bn: candle_nn::BatchNorm,
+ span: tracing::Span,
+}
+
+impl Conv2dBN {
+ fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
+ let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
+ let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+ Ok(Self { c, bn, span })
+ }
+}
+
+impl Module for Conv2dBN {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.c)?.apply(&self.bn)
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl PatchEmbed {
+ fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ padding: 1,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
+ let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { conv1, conv2, span })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ conv3: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl MBConv {
+ fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
+ let hidden = in_ * expand_ratio;
+ let 
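+ // `groups: hidden` below makes conv2 a 3x3 depthwise convolution: the
+ // classic inverted bottleneck (1x1 expand, 3x3 depthwise, 1x1 project).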
cfg2 = candle_nn::Conv2dConfig { + padding: 1, + groups: hidden, + ..Default::default() + }; + let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "mb-conv"); + Ok(Self { + conv1, + conv2, + conv3, + span, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let shortcut = xs; + let xs = xs + .apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)?; + (xs + shortcut)?.gelu() + } +} + +#[derive(Debug)] +struct PatchMerging { + conv1: Conv2dBN, + conv2: Conv2dBN, + conv3: Conv2dBN, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl PatchMerging { + fn new( + input_resolution: (usize, usize), + dim: usize, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 }; + let cfg2 = candle_nn::Conv2dConfig { + padding: 1, + stride, + groups: out, + ..Default::default() + }; + let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-merging"); + Ok(Self { + conv1, + conv2, + conv3, + input_resolution, + span, + }) + } +} + +impl Module for PatchMerging { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = if xs.rank() == 3 { + let (h, w) = self.input_resolution; + let b = xs.dim(0)?; + xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))? + } else { + xs.clone() + }; + xs.apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)? + .flatten_from(2)? + .transpose(1, 2) + } +} + +#[derive(Debug)] +struct ConvLayer { + blocks: Vec<MBConv>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl ConvLayer { + fn new( + dim: usize, + out: usize, + input_resolution: (usize, usize), + depth: usize, + downsample: bool, + conv_expand_ratio: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "conv-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for ConvLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? 
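+ // The stage's MBConv blocks run sequentially; the optional
+ // PatchMerging below then reduces the spatial resolution.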
+ } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +struct Mlp { + norm: candle_nn::LayerNorm, + fc1: super::Linear, + fc2: super::Linear, + span: tracing::Span, +} + +impl Mlp { + fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> { + let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?; + let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?; + let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + Ok(Self { + norm, + fc1, + fc2, + span, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.norm)? + .apply(&self.fc1)? + .gelu()? + .apply(&self.fc2) + } +} + +#[derive(Debug)] +struct Attention { + norm: candle_nn::LayerNorm, + qkv: super::Linear, + proj: super::Linear, + ab: Tensor, + key_dim: usize, + num_heads: usize, + d: usize, + dh: usize, + scale: f64, + span: tracing::Span, + span_matmul: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + key_dim: usize, + num_heads: usize, + attn_ratio: usize, + resolution: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let d = attn_ratio * key_dim; + let dh = d * num_heads; + let nh_kd = key_dim * num_heads; + let h = dh + nh_kd * 2; + let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?; + let qkv = super::linear(vb.pp("qkv"), dim, h, true)?; + let proj = super::linear(vb.pp("proj"), dh, dim, true)?; + + let points = (0..resolution.0) + .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64))) + .collect::<Vec<_>>(); + let mut idxs = Vec::with_capacity(points.len() * points.len()); + let mut attention_offsets = std::collections::HashMap::new(); + for &(x1, y1) in points.iter() { + for &(x2, y2) in points.iter() { + let offset = ((x2 - x1).abs(), (y2 - y1).abs()); + let l = attention_offsets.len(); + let idx = attention_offsets.entry(offset).or_insert(l); + idxs.push(*idx as u32) + } + } + let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; + let idxs = Tensor::new(idxs, attention_biases.device())?; + let ab = + attention_biases + .index_select(&idxs, 1)? + .reshape(((), points.len(), points.len()))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + Ok(Self { + norm, + qkv, + proj, + ab, + key_dim, + num_heads, + d, + dh, + scale: 1f64 / (key_dim as f64).sqrt(), + span, + span_matmul, + span_softmax, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, n, _) = xs.dims3()?; + let xs = xs.apply(&self.norm)?; + let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; + let q = qkv + .narrow(D::Minus1, 0, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let k = qkv + .narrow(D::Minus1, self.key_dim, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let v = qkv + .narrow(D::Minus1, 2 * self.key_dim, self.d)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let attn = { + let _enter = self.span_matmul.enter(); + (q.matmul(&k.t()?)? * self.scale)? + }; + let attn = attn.broadcast_add(&self.ab)?; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? 
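+ // `ab`, added above, is the learned relative-position bias: one
+ // entry per (head, offset) pair, gathered for every token pair in
+ // the attention window.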
+ }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + attn.transpose(1, 2)? + .reshape((b, n, self.dh))? + .apply(&self.proj) + } +} + +#[derive(Debug)] +struct TinyViTBlock { + attn: Attention, + local_conv: Conv2dBN, + mlp: Mlp, + window_size: usize, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl TinyViTBlock { + fn new( + dim: usize, + input_resolution: (usize, usize), + num_heads: usize, + window_size: usize, + vb: VarBuilder, + ) -> Result<Self> { + let head_dim = dim / num_heads; + let attn = Attention::new( + dim, + head_dim, + num_heads, + 1, + (window_size, window_size), + vb.pp("attn"), + )?; + let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; + let cfg = candle_nn::Conv2dConfig { + padding: LOCAL_CONV_SIZE / 2, + groups: dim, + ..Default::default() + }; + let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + Ok(Self { + attn, + local_conv, + mlp, + window_size, + input_resolution, + span, + }) + } +} + +impl Module for TinyViTBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (h, w) = self.input_resolution; + let (b, l, c) = xs.dims3()?; + let res_x = xs; + let xs = if h == self.window_size && w == self.window_size { + self.attn.forward(xs)? + } else { + let xs = xs.reshape((b, h, w, c))?; + let pad_b = (self.window_size - h % self.window_size) % self.window_size; + let pad_r = (self.window_size - w % self.window_size) % self.window_size; + + let xs = if pad_b > 0 { + xs.pad_with_zeros(1, 0, pad_b)? + } else { + xs + }; + let xs = if pad_r > 0 { + xs.pad_with_zeros(2, 0, pad_r)? + } else { + xs + }; + let (p_h, p_w) = (h + pad_b, w + pad_r); + let n_h = p_h / self.window_size; + let n_w = p_w / self.window_size; + let xs = xs + .reshape((b, n_h, self.window_size, n_w, self.window_size, c))? + .transpose(2, 3)? + .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?; + let xs = self.attn.forward(&xs)?; + let xs = xs + .reshape((b, n_h, n_w, self.window_size, self.window_size, c))? + .transpose(2, 3)? + .reshape((b, p_h, p_w, c))?; + let xs = if pad_r > 0 { + xs.i((.., .., ..w))?.contiguous()? + } else { + xs + }; + let xs = if pad_b > 0 { + xs.i((.., ..h, ..))?.contiguous()? + } else { + xs + }; + xs.reshape((b, l, c))? + }; + let xs = (xs + res_x)?; + let xs = xs + .transpose(1, 2)? + .reshape((b, c, h, w))? + .apply(&self.local_conv)? + .reshape((b, c, l))? + .transpose(1, 2)?; + &xs + self.mlp.forward(&xs)? 
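+ // Block layout: windowed attention with a residual, then a depthwise
+ // local conv applied in (B, C, H, W) layout, then a residual MLP.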
+ } +} + +#[derive(Debug)] +struct BasicLayer { + blocks: Vec<TinyViTBlock>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl BasicLayer { + #[allow(clippy::too_many_arguments)] + fn new( + dim: usize, + input_resolution: (usize, usize), + depth: usize, + num_heads: usize, + window_size: usize, + downsample: bool, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = TinyViTBlock::new( + dim, + input_resolution, + num_heads, + window_size, + vb_b.pp(index), + )?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "basic-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for BasicLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +pub struct TinyViT { + patch_embed: PatchEmbed, + layer0: ConvLayer, + layers: Vec<BasicLayer>, + // norm_head: candle_nn::LayerNorm, + // head: candle_nn::Linear, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + span: tracing::Span, + span_neck: tracing::Span, +} + +impl TinyViT { + pub fn new( + embed_dims: &[usize], + depths: &[usize], + num_heads: &[usize], + window_sizes: &[usize], + _num_classes: usize, + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; + let patches_resolution = IMG_SIZE / 4; + + let vb_l = vb.pp("layers"); + let layer0 = ConvLayer::new( + /* dim */ embed_dims[0], + /* out */ embed_dims[1], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[0], + /* downsample */ true, + /* conv_expand_ratio */ MBCONV_EXPAND_RATIO, + vb_l.pp(0), + )?; + + let num_layers = embed_dims.len(); + let mut layers = Vec::with_capacity(num_layers - 1); + for i_layer in 1..num_layers { + let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2)); + let layer = BasicLayer::new( + /* dim */ embed_dims[i_layer], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[i_layer], + /* num_heads */ num_heads[i_layer], + /* window_size */ window_sizes[i_layer], + /* downsample */ i_layer < num_layers - 1, + /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)], + vb_l.pp(i_layer), + )?; + layers.push(layer) + } + + let last_embed_dim = embed_dims[embed_dims.len() - 1]; + // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; + // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; + let neck_conv1 = + candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; + let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; + + let span = tracing::span!(tracing::Level::TRACE, "tiny-vit"); + let span_neck = 
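+ // The neck maps the final (B, 64 * 64, C) token sequence back to a
+ // (B, 256, 64, 64) feature map, the embedding shape the SAM mask
+ // decoder expects.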
tracing::span!(tracing::Level::TRACE, "neck"); + Ok(Self { + patch_embed, + layer0, + layers, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + span, + span_neck, + }) + } +} + +impl Module for TinyViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = self.layer0.forward(&xs)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + let (b, _, c) = xs.dims3()?; + let _enter = self.span_neck.enter(); + xs.reshape((b, 64, 64, c))? + .permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? + .apply(&self.neck_ln2) + } +} + +pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> { + TinyViT::new( + /* embed_dims */ &[64, 128, 160, 320], + /* depths */ &[2, 2, 6, 2], + /* num_heads */ &[2, 4, 5, 10], + /* window_sizes */ &[7, 7, 14, 7], + /* num_classes */ 1000, + vb, + ) +} diff --git a/candle-transformers/src/models/segment_anything/transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs new file mode 100644 index 00000000..80efb38c --- /dev/null +++ b/candle-transformers/src/models/segment_anything/transformer.rs @@ -0,0 +1,221 @@ +use candle::{Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +#[derive(Debug)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + num_heads: usize, +} + +impl Attention { + fn new( + embedding_dim: usize, + num_heads: usize, + downsample_rate: usize, + vb: VarBuilder, + ) -> Result<Self> { + let internal_dim = embedding_dim / downsample_rate; + let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?; + let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?; + let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?; + let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + num_heads, + }) + } + + fn separate_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n, c) = x.dims3()?; + x.reshape((b, n, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? + .contiguous() + } + + fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n_heads, n_tokens, c_per_head) = x.dims4()?; + x.transpose(1, 2)? + .reshape((b, n_tokens, n_heads * c_per_head)) + } + + fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { + let q = self.q_proj.forward(&q.contiguous()?)?; + let k = self.k_proj.forward(&k.contiguous()?)?; + let v = self.v_proj.forward(&v.contiguous()?)?; + + let q = self.separate_heads(&q)?; + let k = self.separate_heads(&k)?; + let v = self.separate_heads(&v)?; + + let (_, _, _, c_per_head) = q.dims4()?; + let attn = (q.matmul(&k.t()?)? 
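+ // Standard scaled dot-product attention: scores are scaled by
+ // 1 / sqrt(c_per_head) before the softmax.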
/ (c_per_head as f64).sqrt())?; + let attn = candle_nn::ops::softmax_last_dim(&attn)?; + + let out = attn.matmul(&v)?; + self.recombine_heads(&out)?.apply(&self.out_proj) + } +} + +#[derive(Debug)] +struct TwoWayAttentionBlock { + self_attn: Attention, + norm1: LayerNorm, + cross_attn_token_to_image: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + norm3: LayerNorm, + norm4: LayerNorm, + cross_attn_image_to_token: Attention, + skip_first_layer_pe: bool, +} + +impl TwoWayAttentionBlock { + fn new( + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + skip_first_layer_pe: bool, + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?; + let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?; + let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?; + let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?; + let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?; + let cross_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_token_to_image"), + )?; + let cross_attn_image_to_token = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_image_to_token"), + )?; + let mlp = super::MlpBlock::new( + embedding_dim, + mlp_dim, + candle_nn::Activation::Relu, + vb.pp("mlp"), + )?; + Ok(Self { + self_attn, + norm1, + cross_attn_image_to_token, + norm2, + mlp, + norm3, + norm4, + cross_attn_token_to_image, + skip_first_layer_pe, + }) + } + + fn forward( + &self, + queries: &Tensor, + keys: &Tensor, + query_pe: &Tensor, + key_pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // Self attention block + let queries = if self.skip_first_layer_pe { + self.self_attn.forward(queries, queries, queries)? + } else { + let q = (queries + query_pe)?; + let attn_out = self.self_attn.forward(&q, &q, queries)?; + (queries + attn_out)? 
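+ // On the first layer the queries are the point embeddings themselves,
+ // so the positional term is skipped; later layers re-add `query_pe`
+ // before each attention step.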
+ }; + let queries = self.norm1.forward(&queries)?; + + // Cross attention block, tokens attending to image embedding + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?; + let queries = (&queries + attn_out)?; + let queries = self.norm2.forward(&queries)?; + + // MLP block + let mlp_out = self.mlp.forward(&queries); + let queries = (queries + mlp_out)?; + let queries = self.norm3.forward(&queries)?; + + // Cross attention block, image embedding attending to tokens + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?; + let keys = (keys + attn_out)?; + let keys = self.norm4.forward(&keys)?; + + Ok((queries, keys)) + } +} + +#[derive(Debug)] +pub struct TwoWayTransformer { + layers: Vec<TwoWayAttentionBlock>, + final_attn_token_to_image: Attention, + norm_final_attn: LayerNorm, +} + +impl TwoWayTransformer { + pub fn new( + depth: usize, + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_l = vb.pp("layers"); + let mut layers = Vec::with_capacity(depth); + for i in 0..depth { + let layer = + TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?; + layers.push(layer) + } + let final_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("final_attn_token_to_image"), + )?; + let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?; + Ok(Self { + layers, + final_attn_token_to_image, + norm_final_attn, + }) + } + + pub fn forward( + &self, + image_embedding: &Tensor, + image_pe: &Tensor, + point_embedding: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?; + let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?; + + let mut queries = point_embedding.clone(); + let mut keys = image_embedding; + + for layer in self.layers.iter() { + (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)? + } + + let q = (&queries + point_embedding)?; + let k = (&keys + image_pe)?; + let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?; + let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?; + + Ok((queries, keys)) + } +} diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs new file mode 100644 index 00000000..ce579316 --- /dev/null +++ b/candle-transformers/src/object_detection.rs @@ -0,0 +1,52 @@ +/// A bounding box around an object. +#[derive(Debug, Clone)] +pub struct Bbox<D> { + pub xmin: f32, + pub ymin: f32, + pub xmax: f32, + pub ymax: f32, + pub confidence: f32, + pub data: D, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct KeyPoint { + pub x: f32, + pub y: f32, + pub mask: f32, +} + +/// Intersection over union of two bounding boxes. +pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 { + let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); + let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); + let i_xmin = b1.xmin.max(b2.xmin); + let i_xmax = b1.xmax.min(b2.xmax); + let i_ymin = b1.ymin.max(b2.ymin); + let i_ymax = b1.ymax.min(b2.ymax); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + i_area / (b1_area + b2_area - i_area) +} + +pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) { + // Perform non-maximum suppression. 
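+ // Boxes are sorted by decreasing confidence within each class; a box
+ // is kept only if its IoU with every previously kept box is at most
+ // `threshold`. Kept boxes are compacted to the front in place and the
+ // tail is truncated.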
+ for bboxes_for_class in bboxes.iter_mut() {
+ bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
+ let mut current_index = 0;
+ for index in 0..bboxes_for_class.len() {
+ let mut drop = false;
+ for prev_index in 0..current_index {
+ let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
+ if iou > threshold {
+ drop = true;
+ break;
+ }
+ }
+ if !drop {
+ bboxes_for_class.swap(current_index, index);
+ current_index += 1;
+ }
+ }
+ bboxes_for_class.truncate(current_index);
+ }
+}
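+
+// Worked example (sketch, with a unit payload in `data`): two heavily
+// overlapping boxes have IoU 100 / 142 ~= 0.70 under the +1 pixel
+// convention above, so at a 0.5 threshold the lower-confidence one is
+// suppressed:
+//
+// let mut dets = vec![vec![
+// Bbox { xmin: 0., ymin: 0., xmax: 10., ymax: 10., confidence: 0.9, data: () },
+// Bbox { xmin: 1., ymin: 1., xmax: 11., ymax: 11., confidence: 0.8, data: () },
+// ]];
+// non_maximum_suppression(&mut dets, 0.5);
+// assert_eq!(dets[0].len(), 1);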