summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/dinov2/main.rs283
-rw-r--r--candle-examples/examples/efficientnet/main.rs335
-rw-r--r--candle-examples/examples/quantized/main.rs2
-rw-r--r--candle-examples/examples/quantized/model.rs371
-rw-r--r--candle-examples/examples/segment-anything/main.rs109
-rw-r--r--candle-examples/examples/segment-anything/model_image_encoder.rs483
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs239
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs239
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs411
-rw-r--r--candle-examples/examples/segment-anything/model_tiny_vit.rs633
-rw-r--r--candle-examples/examples/segment-anything/model_transformer.rs221
-rw-r--r--candle-examples/examples/yolo-v3/main.rs2
-rw-r--r--candle-examples/examples/yolo-v8/main.rs2
13 files changed, 16 insertions, 3314 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index e80c81e2..d3adb37c 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2;
-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- qkv: Linear,
- proj: Linear,
- num_heads: usize,
- scale: f64,
-}
-
-impl Attention {
- fn new(
- vb: VarBuilder,
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- proj_bias: bool,
- ) -> Result<Self> {
- let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
- let scale = 1. / ((dim / num_heads) as f64).sqrt();
- Ok(Self {
- qkv,
- proj,
- num_heads,
- scale,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (b, n, c) = xs.dims3()?;
- let qkv = self
- .qkv
- .forward(xs)?
- .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)? // 02134
- .transpose(0, 1)? // 20134
- .transpose(2, 3)?; // 20314
- let q = (qkv.i(0)? * self.scale)?;
- let k = qkv.i(1)?;
- let v = qkv.i(2)?;
- let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
- let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
- self.proj.forward(&attn)
- }
-}
-
-#[derive(Debug)]
-struct LayerScale {
- gamma: Tensor,
-}
-
-impl LayerScale {
- fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
- let gamma = vb.get(dim, "gamma")?;
- Ok(Self { gamma })
- }
-}
-
-impl Module for LayerScale {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.broadcast_mul(&self.gamma)
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- fc1: Linear,
- fc2: Linear,
-}
-
-impl Mlp {
- fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
- let out_features = in_features;
- let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
- let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.fc1.forward(xs)?.gelu()?;
- self.fc2.forward(&xs)
- }
-}
-
-#[derive(Debug)]
-struct Block {
- norm1: LayerNorm,
- attn: Attention,
- ls1: LayerScale,
- norm2: LayerNorm,
- mlp: Mlp,
- ls2: LayerScale,
-}
-
-impl Block {
- fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
- let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
- let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
- let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
- let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
- let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
- let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
- Ok(Self {
- norm1,
- attn,
- ls1,
- norm2,
- mlp,
- ls2,
- })
- }
-}
-
-impl Module for Block {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- let xs = self
- .ls1
- .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
- let xs = (xs + residual)?;
- let residual = &xs;
- let xs = self
- .ls2
- .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
- xs + residual
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- proj: candle_nn::Conv2d,
- patch_size: (usize, usize),
- num_patches: usize,
-}
-
-impl PatchEmbed {
- fn new(
- vb: VarBuilder,
- img_size: usize,
- patch_size: usize,
- in_chans: usize,
- embed_dim: usize,
- ) -> Result<Self> {
- let config = candle_nn::Conv2dConfig {
- stride: patch_size,
- ..Default::default()
- };
- let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
- let num_patches = (img_size / patch_size) * (img_size / patch_size);
- Ok(Self {
- proj,
- patch_size: (patch_size, patch_size),
- num_patches,
- })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _c, h, w) = xs.dims4()?;
- let (patch_h, patch_w) = self.patch_size;
- if (h % patch_h) != 0 {
- candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
- }
- if (w % patch_w) != 0 {
- candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
- }
- let xs = self.proj.forward(xs)?;
- let (b, c, h, w) = xs.dims4()?;
- // flatten embeddings.
- xs.reshape((b, c, h * w))?.transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-pub struct DinoVisionTransformer {
- patch_embed: PatchEmbed,
- cls_token: Tensor,
- pos_embed: Tensor,
- blocks: Vec<Block>,
- norm: LayerNorm,
- head: Linear,
-}
-
-impl DinoVisionTransformer {
- pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
- let patch_embed =
- PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
- let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
- let num_tokens = 1;
- let pos_embed = vb.get(
- (1, patch_embed.num_patches + num_tokens, embed_dim),
- "pos_embed",
- )?;
- let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
- let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
- let vb_b = vb.pp("blocks");
- let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- patch_embed,
- cls_token,
- pos_embed,
- blocks,
- norm,
- head,
- })
- }
-
- fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
- let npatch = xs.dim(1)? - 1;
- let n = self.pos_embed.dim(1)? - 1;
- let sqrt_n = (n as f64).sqrt();
- if npatch == n && w == h {
- return Ok(xs.clone());
- }
- let class_pos_embed = self.pos_embed.i((.., ..1))?;
- let patch_pos_embed = self.pos_embed.i((.., 1..))?;
- let dim = xs.dim(D::Minus1)?;
- let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
- let patch_pos_embed = patch_pos_embed
- .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
- .transpose(2, 3)?
- .transpose(1, 2)?;
- // This uses bicubic interpolation in the original implementation.
- let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
- let el_count = patch_pos_embed.shape().elem_count();
- let patch_pos_embed =
- patch_pos_embed
- .transpose(1, 2)?
- .transpose(2, 3)?
- .reshape((1, el_count / dim, dim))?;
- Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
- }
-
- fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _nc, w, h) = xs.dims4()?;
- let xs = self.patch_embed.forward(xs)?;
- let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
- &xs + &self.interpolate_pos_encoding(&xs, w, h)?
- }
-}
-
-impl Module for DinoVisionTransformer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.prepare_tokens_with_mask(xs)?;
- for blk in self.blocks.iter() {
- xs = blk.forward(&xs)?
- }
- let xs = self.norm.forward(&xs)?;
- let xs_norm_clstoken = xs.i((.., 0))?;
- let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
- let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
- self.head.forward(&xs)
- }
-}
-
-pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
- DinoVisionTransformer::new(vb, 12, 384, 6)
-}
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
- let model = vit_small(vb)?;
+ let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index cbe2c90a..1e45e301 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn as nn;
-use nn::{Module, VarBuilder};
-
-// Based on the Python version from torchvision.
-// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
-#[derive(Debug, Clone, Copy)]
-pub struct MBConvConfig {
- expand_ratio: f64,
- kernel: usize,
- stride: usize,
- input_channels: usize,
- out_channels: usize,
- num_layers: usize,
-}
-
-fn make_divisible(v: f64, divisor: usize) -> usize {
- let min_value = divisor;
- let new_v = usize::max(
- min_value,
- (v + divisor as f64 * 0.5) as usize / divisor * divisor,
- );
- if (new_v as f64) < 0.9 * v {
- new_v + divisor
- } else {
- new_v
- }
-}
-
-fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
- let bneck_conf = |e, k, s, i, o, n| {
- let input_channels = make_divisible(i as f64 * width_mult, 8);
- let out_channels = make_divisible(o as f64 * width_mult, 8);
- let num_layers = (n as f64 * depth_mult).ceil() as usize;
- MBConvConfig {
- expand_ratio: e,
- kernel: k,
- stride: s,
- input_channels,
- out_channels,
- num_layers,
- }
- };
- vec![
- bneck_conf(1., 3, 1, 32, 16, 1),
- bneck_conf(6., 3, 2, 16, 24, 2),
- bneck_conf(6., 5, 2, 24, 40, 2),
- bneck_conf(6., 3, 2, 40, 80, 3),
- bneck_conf(6., 5, 1, 80, 112, 3),
- bneck_conf(6., 5, 2, 112, 192, 4),
- bneck_conf(6., 3, 1, 192, 320, 1),
- ]
-}
-
-impl MBConvConfig {
- fn b0() -> Vec<Self> {
- bneck_confs(1.0, 1.0)
- }
- fn b1() -> Vec<Self> {
- bneck_confs(1.0, 1.1)
- }
- fn b2() -> Vec<Self> {
- bneck_confs(1.1, 1.2)
- }
- fn b3() -> Vec<Self> {
- bneck_confs(1.2, 1.4)
- }
- fn b4() -> Vec<Self> {
- bneck_confs(1.4, 1.8)
- }
- fn b5() -> Vec<Self> {
- bneck_confs(1.6, 2.2)
- }
- fn b6() -> Vec<Self> {
- bneck_confs(1.8, 2.6)
- }
- fn b7() -> Vec<Self> {
- bneck_confs(2.0, 3.1)
- }
-}
-
-/// Conv2D with same padding.
-#[derive(Debug)]
-struct Conv2DSame {
- conv2d: nn::Conv2d,
- s: usize,
- k: usize,
-}
-
-impl Conv2DSame {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- bias: bool,
- ) -> Result<Self> {
- let conv_config = nn::Conv2dConfig {
- stride,
- groups,
- ..Default::default()
- };
- let conv2d = if bias {
- nn::conv2d(i, o, k, conv_config, vb)?
- } else {
- nn::conv2d_no_bias(i, o, k, conv_config, vb)?
- };
- Ok(Self {
- conv2d,
- s: stride,
- k,
- })
- }
-}
-
-impl Module for Conv2DSame {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let s = self.s;
- let k = self.k;
- let (_, _, ih, iw) = xs.dims4()?;
- let oh = (ih + s - 1) / s;
- let ow = (iw + s - 1) / s;
- let pad_h = usize::max((oh - 1) * s + k - ih, 0);
- let pad_w = usize::max((ow - 1) * s + k - iw, 0);
- if pad_h > 0 || pad_w > 0 {
- let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
- let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
- self.conv2d.forward(&xs)
- } else {
- self.conv2d.forward(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct ConvNormActivation {
- conv2d: Conv2DSame,
- bn2d: nn::BatchNorm,
- activation: bool,
-}
-
-impl ConvNormActivation {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- ) -> Result<Self> {
- let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
- let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
- Ok(Self {
- conv2d,
- bn2d,
- activation: true,
- })
- }
-
- fn no_activation(self) -> Self {
- Self {
- activation: false,
- ..self
- }
- }
-}
-
-impl Module for ConvNormActivation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.conv2d.forward(xs)?;
- let xs = self.bn2d.forward(&xs)?;
- if self.activation {
- swish(&xs)
- } else {
- Ok(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct SqueezeExcitation {
- fc1: Conv2DSame,
- fc2: Conv2DSame,
-}
-
-impl SqueezeExcitation {
- fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
- let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
- let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for SqueezeExcitation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- // equivalent to adaptive_avg_pool2d([1, 1])
- let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = swish(&xs)?;
- let xs = self.fc2.forward(&xs)?;
- let xs = nn::ops::sigmoid(&xs)?;
- residual.broadcast_mul(&xs)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- expand_cna: Option<ConvNormActivation>,
- depthwise_cna: ConvNormActivation,
- squeeze_excitation: SqueezeExcitation,
- project_cna: ConvNormActivation,
- config: MBConvConfig,
-}
-
-impl MBConv {
- fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
- let vb = vb.pp("block");
- let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
- let expand_cna = if exp != c.input_channels {
- Some(ConvNormActivation::new(
- vb.pp("0"),
- c.input_channels,
- exp,
- 1,
- 1,
- 1,
- )?)
- } else {
- None
- };
- let start_index = if expand_cna.is_some() { 1 } else { 0 };
- let depthwise_cna =
- ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
- let squeeze_channels = usize::max(1, c.input_channels / 4);
- let squeeze_excitation =
- SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
- let project_cna =
- ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
- .no_activation();
- Ok(Self {
- expand_cna,
- depthwise_cna,
- squeeze_excitation,
- project_cna,
- config: c,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let use_res_connect =
- self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
- let ys = match &self.expand_cna {
- Some(expand_cna) => expand_cna.forward(xs)?,
- None => xs.clone(),
- };
- let ys = self.depthwise_cna.forward(&ys)?;
- let ys = self.squeeze_excitation.forward(&ys)?;
- let ys = self.project_cna.forward(&ys)?;
- if use_res_connect {
- ys + xs
- } else {
- Ok(ys)
- }
- }
-}
-
-fn swish(s: &Tensor) -> Result<Tensor> {
- s * nn::ops::sigmoid(s)?
-}
-
-#[derive(Debug)]
-struct EfficientNet {
- init_cna: ConvNormActivation,
- blocks: Vec<MBConv>,
- final_cna: ConvNormActivation,
- classifier: nn::Linear,
-}
-
-impl EfficientNet {
- fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
- let f_p = p.pp("features");
- let first_in_c = configs[0].input_channels;
- let last_out_c = configs.last().unwrap().out_channels;
- let final_out_c = 4 * last_out_c;
- let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
- let nconfigs = configs.len();
- let mut blocks = vec![];
- for (index, cnf) in configs.into_iter().enumerate() {
- let f_p = f_p.pp(index + 1);
- for r_index in 0..cnf.num_layers {
- let cnf = if r_index == 0 {
- cnf
- } else {
- MBConvConfig {
- input_channels: cnf.out_channels,
- stride: 1,
- ..cnf
- }
- };
- blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
- }
- }
- let final_cna =
- ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
- let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
- Ok(Self {
- init_cna,
- blocks,
- final_cna,
- classifier,
- })
- }
-}
-
-impl Module for EfficientNet {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.init_cna.forward(xs)?;
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- let xs = self.final_cna.forward(&xs)?;
- // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
- let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
- self.classifier.forward(&xs)
- }
-}
-
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a3f98d8e..c8179d33 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
-mod model;
+use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
diff --git a/candle-examples/examples/quantized/model.rs b/candle-examples/examples/quantized/model.rs
deleted file mode 100644
index da0bd0b0..00000000
--- a/candle-examples/examples/quantized/model.rs
+++ /dev/null
@@ -1,371 +0,0 @@
-use std::collections::HashMap;
-
-use candle::quantized::QTensor;
-use candle::quantized::{ggml_file, gguf_file};
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module};
-
-pub const MAX_SEQ_LEN: usize = 4096;
-
-struct RmsNorm {
- inner: candle_nn::LayerNorm,
- span: tracing::Span,
-}
-
-impl RmsNorm {
- fn new(scale: QTensor, eps: f32) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let scale = scale.dequantize(&Device::Cpu)?;
- let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
- Ok(Self { inner, span })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-// QMatMul wrapper adding some tracing.
-struct QMatMul {
- inner: candle::quantized::QMatMul,
- span: tracing::Span,
-}
-
-impl QMatMul {
- fn from_qtensor(qtensor: QTensor) -> Self {
- let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
- let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
- Self { inner, span }
- }
-
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(xs)
- }
-}
-
-struct LayerWeights {
- attention_wq: QMatMul,
- attention_wk: QMatMul,
- attention_wv: QMatMul,
- attention_wo: QMatMul,
- attention_norm: RmsNorm,
- feed_forward_w1: QMatMul,
- feed_forward_w2: QMatMul,
- feed_forward_w3: QMatMul,
- ffn_norm: RmsNorm,
- n_head: usize,
- n_kv_head: usize,
- head_dim: usize,
- cos: Tensor,
- sin: Tensor,
- kv_cache: Option<(Tensor, Tensor)>,
- span_attn: tracing::Span,
- span_rot: tracing::Span,
- span_mlp: tracing::Span,
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
- let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
- let m = mask.where_cond(&on_true, on_false)?;
- Ok(m)
-}
-
-impl LayerWeights {
- fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let _enter = self.span_rot.enter();
- let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
- let cos = self
- .cos
- .narrow(0, index_pos, seq_len)?
- .reshape((seq_len, n_embd / 2, 1))?;
- let sin = self
- .sin
- .narrow(0, index_pos, seq_len)?
- .reshape((seq_len, n_embd / 2, 1))?;
- let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
- let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
- // This mimics the llama.cpp behavior.
- // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
- // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
- // The resulting y0 and y1 are also interleaved with:
- // y0 = x0*cos - x1*sin
- // y1 = x0*sin + x1*cos
- let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
- let x0 = x.narrow(D::Minus1, 0, 1)?;
- let x1 = x.narrow(D::Minus1, 1, 1)?;
- let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
- let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
- let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
- let rope = rope.flatten_from(D::Minus2)?;
- Ok(rope)
- }
-
- fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
- let _enter = self.span_attn.enter();
- let (b_sz, seq_len, n_embd) = x.dims3()?;
- let q = self.attention_wq.forward(x)?;
- let k = self.attention_wk.forward(x)?;
- let v = self.attention_wv.forward(x)?;
-
- let q = q
- .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
- .transpose(1, 2)?;
- let k = k
- .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
- .transpose(1, 2)?;
- let v = v
- .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
- .transpose(1, 2)?;
-
- let q = self.apply_rotary_emb(&q, index_pos)?;
- let k = self.apply_rotary_emb(&k, index_pos)?;
-
- let (k, v) = match &self.kv_cache {
- None => (k, v),
- Some((k_cache, v_cache)) => {
- if index_pos == 0 {
- (k, v)
- } else {
- let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
- let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
- (k, v)
- }
- }
- };
- self.kv_cache = Some((k.clone(), v.clone()));
-
- // Support for MQA, useful for 70B models.
- let k = self.repeat_kv(k)?;
- let v = self.repeat_kv(v)?;
-
- let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
- let mask = mask.broadcast_as(att.shape())?;
- let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
- // Convert to contiguous as matmul doesn't support strided vs for now.
- let y = att.matmul(&v.contiguous()?)?;
- let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
- let y = self.attention_wo.forward(&y)?;
- Ok(y)
- }
-
- fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.n_head / self.n_kv_head;
- if n_rep == 1 {
- Ok(x)
- } else {
- let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(2)?
- .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
- .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
- Ok(x)
- }
- }
-}
-
-pub struct ModelWeights {
- tok_embeddings: Embedding,
- layers: Vec<LayerWeights>,
- norm: RmsNorm,
- output: QMatMul,
- masks: HashMap<usize, Tensor>,
- span: tracing::Span,
- span_output: tracing::Span,
-}
-
-fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
- let theta: Vec<_> = (0..head_dim)
- .step_by(2)
- .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
- .collect();
- let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
- let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
- .to_dtype(DType::F32)?
- .reshape((MAX_SEQ_LEN, 1))?
- .matmul(&theta.reshape((1, theta.elem_count()))?)?;
- let cos = idx_theta.cos()?;
- let sin = idx_theta.sin()?;
- Ok((cos, sin))
-}
-
-impl ModelWeights {
- pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
- let cpu = &Device::Cpu;
- let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
- let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
- let tok_embeddings = ct.remove("tok_embeddings.weight")?;
- let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
- let output = ct.remove("output.weight")?;
- let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
- for layer_idx in 0..ct.hparams.n_layer {
- let prefix = format!("layers.{layer_idx}");
- let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
- let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
- let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
- let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
- let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
- let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
- let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
- let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
- let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
- let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
- let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
- let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
- layers.push(LayerWeights {
- attention_wq: QMatMul::from_qtensor(attention_wq),
- attention_wk: QMatMul::from_qtensor(attention_wk),
- attention_wv: QMatMul::from_qtensor(attention_wv),
- attention_wo: QMatMul::from_qtensor(attention_wo),
- attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
- ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
- n_head: ct.hparams.n_head as usize,
- n_kv_head: ct.hparams.n_head as usize / gqa,
- head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
- cos: cos.clone(),
- sin: sin.clone(),
- kv_cache: None,
- span_attn,
- span_rot,
- span_mlp,
- })
- }
- let span = tracing::span!(tracing::Level::TRACE, "model");
- let span_output = tracing::span!(tracing::Level::TRACE, "output");
- Ok(Self {
- tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
- layers,
- norm,
- output: QMatMul::from_qtensor(output),
- masks: HashMap::new(),
- span,
- span_output,
- })
- }
-
- pub fn from_gguf<R: std::io::Seek + std::io::Read>(
- ct: gguf_file::Content,
- reader: &mut R,
- ) -> Result<Self> {
- let cpu = &Device::Cpu;
- let md_get = |s: &str| match ct.metadata.get(s) {
- None => candle::bail!("cannot find {s} in metadata"),
- Some(v) => Ok(v),
- };
-
- // Parameter extraction from metadata.
- let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
- let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
- let block_count = md_get("llama.block_count")?.to_u32()? as usize;
- let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
- let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
- // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
- let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
-
- let rope_freq_base = md_get("llama.rope.freq_base")
- .and_then(|m| m.to_f32())
- .unwrap_or(10000f32);
- let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
-
- let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
- let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
- let output = ct.tensor(reader, "output.weight")?;
- let mut layers = Vec::with_capacity(block_count);
- for layer_idx in 0..block_count {
- let prefix = format!("blk.{layer_idx}");
- let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
- let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
- let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
- let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
- let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
- let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
- let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
- let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
- let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
- let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
- let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
- let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
- layers.push(LayerWeights {
- attention_wq: QMatMul::from_qtensor(attention_wq),
- attention_wk: QMatMul::from_qtensor(attention_wk),
- attention_wv: QMatMul::from_qtensor(attention_wv),
- attention_wo: QMatMul::from_qtensor(attention_wo),
- attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
- ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
- n_head: head_count,
- n_kv_head: head_count_kv,
- head_dim: embedding_length / head_count,
- cos: cos.clone(),
- sin: sin.clone(),
- kv_cache: None,
- span_attn,
- span_rot,
- span_mlp,
- })
- }
- let span = tracing::span!(tracing::Level::TRACE, "model");
- let span_output = tracing::span!(tracing::Level::TRACE, "output");
- Ok(Self {
- tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
- layers,
- norm,
- output: QMatMul::from_qtensor(output),
- masks: HashMap::new(),
- span,
- span_output,
- })
- }
-
- fn mask(&mut self, t: usize) -> Result<Tensor> {
- if let Some(mask) = self.masks.get(&t) {
- Ok(mask.clone())
- } else {
- let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
- .collect();
- let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
- self.masks.insert(t, mask.clone());
- Ok(mask)
- }
- }
-
- pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (_b_sz, seq_len) = x.dims2()?;
- let mask = self.mask(seq_len)?;
- let _enter = self.span.enter();
- let mut layer_in = self.tok_embeddings.forward(x)?;
- for layer in self.layers.iter_mut() {
- let x = layer_in;
- let residual = &x;
- let x = layer.attention_norm.forward(&x)?;
- let attn = layer.forward_attn(&x, &mask, index_pos)?;
- let x = (attn + residual)?;
-
- // MLP
- let _enter = layer.span_mlp.enter();
- let residual = &x;
- let x = layer.ffn_norm.forward(&x)?;
- let w1 = layer.feed_forward_w1.forward(&x)?;
- let w3 = layer.feed_forward_w3.forward(&x)?;
- let mlp = layer
- .feed_forward_w2
- .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
- layer_in = (mlp + residual)?;
- }
- let x = self.norm.forward(&layer_in)?;
- let x = x.i((.., seq_len - 1, ..))?;
- let _enter = self.span_output.enter();
- self.output.forward(&x)
- }
-}
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 9ce2f158..21ba0415 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-pub mod model_image_encoder;
-pub mod model_mask_decoder;
-pub mod model_prompt_encoder;
-pub mod model_sam;
-pub mod model_tiny_vit;
-pub mod model_transformer;
-
-use candle::{DType, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segment_anything::sam;
use clap::Parser;
-pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- let inner = if bias {
- candle_nn::linear(in_dim, out_dim, vb)?
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)?
- };
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Ok(Linear { inner, span })
-}
-
-#[derive(Debug)]
-pub struct LayerNorm2d {
- weight: Tensor,
- bias: Tensor,
- num_channels: usize,
- eps: f64,
-}
-
-impl LayerNorm2d {
- pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let weight = vb.get(num_channels, "weight")?;
- let bias = vb.get(num_channels, "bias")?;
- Ok(Self {
- weight,
- bias,
- num_channels,
- eps,
- })
- }
-}
-
-impl Module for LayerNorm2d {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let u = xs.mean_keepdim(1)?;
- let xs = xs.broadcast_sub(&u)?;
- let s = xs.sqr()?.mean_keepdim(1)?;
- let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
- xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
- .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
- }
-}
-
-#[derive(Debug)]
-pub struct MlpBlock {
- lin1: Linear,
- lin2: Linear,
- activation: candle_nn::Activation,
- span: tracing::Span,
-}
-
-impl MlpBlock {
- pub fn new(
- embedding_dim: usize,
- mlp_dim: usize,
- activation: candle_nn::Activation,
- vb: VarBuilder,
- ) -> Result<Self> {
- let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
- let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
- let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
- Ok(Self {
- lin1,
- lin2,
- activation,
- span,
- })
- }
-}
-
-impl Module for MlpBlock {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.lin1)?
- .apply(&self.activation)?
- .apply(&self.lin2)
- }
-}
-
-#[derive(Debug)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
- let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
+ let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
- model_sam::Sam::new_tiny(vb)? // tiny vit_t
+ sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
- model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
deleted file mode 100644
index 76cd15d0..00000000
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ /dev/null
@@ -1,483 +0,0 @@
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
-
-#[derive(Debug)]
-struct PatchEmbed {
- proj: candle_nn::Conv2d,
- span: tracing::Span,
-}
-
-impl PatchEmbed {
- fn new(
- in_chans: usize,
- embed_dim: usize,
- k_size: usize,
- stride: usize,
- padding: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let cfg = candle_nn::Conv2dConfig {
- stride,
- padding,
- ..Default::default()
- };
- let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
- let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
- Ok(Self { proj, span })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.proj)?.permute((0, 2, 3, 1))
- }
-}
-
-// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
-// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
-// (attn.reshape((b, q_h, q_w, k_h, k_w))?
-// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
-// .reshape((b, q_h * q_w, k_h * k_w))
-// Ideally we would perform this operation in place but this is not supported in candle at the
-// moment. We should also investigate using f16 rather than f32.
-struct Add3(usize, usize, usize, usize, usize);
-impl candle::CustomOp3 for Add3 {
- fn name(&self) -> &'static str {
- "add3"
- }
-
- fn cpu_fwd(
- &self,
- s1: &candle::CpuStorage,
- l1: &candle::Layout,
- s2: &candle::CpuStorage,
- l2: &candle::Layout,
- s3: &candle::CpuStorage,
- l3: &candle::Layout,
- ) -> Result<(candle::CpuStorage, candle::Shape)> {
- use rayon::prelude::*;
-
- let Add3(b, q_h, q_w, k_h, k_w) = *self;
- let s1 = s1.as_slice::<f32>()?;
- let s1 = match l1.contiguous_offsets() {
- None => candle::bail!("input1 has to be contiguous"),
- Some((o1, o2)) => &s1[o1..o2],
- };
- let s2 = s2.as_slice::<f32>()?;
- let s2 = match l2.contiguous_offsets() {
- None => candle::bail!("input2 has to be contiguous"),
- Some((o1, o2)) => &s2[o1..o2],
- };
- let s3 = s3.as_slice::<f32>()?;
- let s3 = match l3.contiguous_offsets() {
- None => candle::bail!("input3 has to be contiguous"),
- Some((o1, o2)) => &s3[o1..o2],
- };
- let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
- dst.par_chunks_exact_mut(k_h * k_w)
- .enumerate()
- .for_each(|(b_idx, dst)| {
- let s1_idx = b_idx * k_h * k_w;
- let s2_idx = b_idx * k_h;
- let s3_idx = b_idx * k_w;
- for h_idx in 0..k_h {
- let s1_idx = s1_idx + h_idx * k_w;
- let s2_idx = s2_idx + h_idx;
- let dst_idx = h_idx * k_w;
- for w_idx in 0..k_w {
- let s1_idx = s1_idx + w_idx;
- let s3_idx = s3_idx + w_idx;
- let dst_idx = dst_idx + w_idx;
- dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
- }
- }
- });
- let dst = candle::WithDType::to_cpu_storage_owned(dst);
- Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- qkv: crate::Linear,
- proj: crate::Linear,
- num_heads: usize,
- scale: f64,
- rel_pos_hw: Option<(Tensor, Tensor)>,
- span: tracing::Span,
- span_matmul: tracing::Span,
- span_rel_pos: tracing::Span,
- span_softmax: tracing::Span,
-}
-
-impl Attention {
- fn new(
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- use_rel_pos: bool,
- input_size: (usize, usize),
- vb: VarBuilder,
- ) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "attention");
- let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
- let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
- let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
- let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
- let head_dim = dim / num_heads;
- let scale = 1. / (head_dim as f64).sqrt();
- let rel_pos_hw = if use_rel_pos {
- let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
- let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
- Some((h, w))
- } else {
- None
- };
- Ok(Self {
- qkv,
- proj,
- num_heads,
- scale,
- rel_pos_hw,
- span,
- span_matmul,
- span_rel_pos,
- span_softmax,
- })
- }
-
- fn add_decomposed_rel_pos(
- &self,
- attn: Tensor,
- q: &Tensor,
- (q_h, q_w): (usize, usize),
- (k_h, k_w): (usize, usize),
- ) -> Result<Tensor> {
- match &self.rel_pos_hw {
- Some((rel_pos_h, rel_pos_w)) => {
- let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
- let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
- let (b, _, dim) = q.dims3()?;
- let r_q = q.reshape((b, q_h, q_w, dim))?;
- // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
- let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
- // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
- let rel_w = r_q
- .transpose(1, 2)? // -> bwhc
- .contiguous()?
- .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
- .transpose(1, 2)?
- .contiguous()?;
- if attn.device().is_cpu() {
- let op = Add3(b, q_h, q_w, k_h, k_w);
- attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
- } else {
- (attn.reshape((b, q_h, q_w, k_h, k_w))?
- + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
- .reshape((b, q_h * q_w, k_h * k_w))
- }
- }
- None => Ok(attn),
- }
- }
-}
-
-fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
- let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
- let dev = rel_pos.device();
- let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
- todo!("interpolation")
- } else {
- rel_pos
- };
- let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
- .reshape((q_size, 1))?
- .to_dtype(DType::F32)?;
- let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
- .reshape((1, k_size))?
- .to_dtype(DType::F32)?;
- let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
- let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
- let relative_coords = (q_coords.broadcast_sub(&k_coords)?
- + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
- let (d1, d2) = relative_coords.dims2()?;
- let relative_coords = relative_coords.to_dtype(DType::U32)?;
- rel_pos_resized
- .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
- .reshape((d1, d2, ()))
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (b, h, w, c) = xs.dims4()?;
- let qkv = self
- .qkv
- .forward(&xs.flatten_to(1)?)?
- .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
- .permute((2, 0, 3, 1, 4))?
- .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
- let q = qkv.i(0)?;
- let k = qkv.i(1)?;
- let v = qkv.i(2)?;
- let attn = {
- let _enter = self.span_matmul.enter();
- (&q * self.scale)?.matmul(&k.t()?)?
- };
- let attn = {
- let _enter = self.span_rel_pos.enter();
- self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
- };
- let attn = {
- let _enter = self.span_softmax.enter();
- candle_nn::ops::softmax_last_dim(&attn)?
- };
- let attn = {
- let _enter = self.span_matmul.enter();
- attn.matmul(&v)?
- };
- let attn = attn
- .reshape((b, self.num_heads, h, w, c / self.num_heads))?
- .permute((0, 2, 3, 1, 4))?
- .reshape((b, h * w, c))?;
- self.proj.forward(&attn)?.reshape((b, h, w, c))
- }
-}
-
-#[derive(Debug)]
-struct Block {
- norm1: LayerNorm,
- attn: Attention,
- norm2: LayerNorm,
- mlp: crate::MlpBlock,
- window_size: usize,
- span: tracing::Span,
-}
-
-impl Block {
- fn new(
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- use_rel_pos: bool,
- window_size: usize,
- input_size: (usize, usize),
- vb: VarBuilder,
- ) -> Result<Self> {
- let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
- let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
- let input_size_attn = if window_size == 0 {
- input_size
- } else {
- (window_size, window_size)
- };
- let attn = Attention::new(
- dim,
- num_heads,
- qkv_bias,
- use_rel_pos,
- input_size_attn,
- vb.pp("attn"),
- )?;
- let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
- let span = tracing::span!(tracing::Level::TRACE, "ie-block");
- Ok(Self {
- norm1,
- attn,
- norm2,
- mlp,
- window_size,
- span,
- })
- }
-}
-
-fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
- let (b, h, w, c) = xs.dims4()?;
- let pad_h = (window_size - h % window_size) % window_size;
- let pad_w = (window_size - w % window_size) % window_size;
- let xs = if pad_h > 0 {
- xs.pad_with_zeros(1, 0, pad_h)?
- } else {
- xs
- };
- let xs = if pad_w > 0 {
- xs.pad_with_zeros(2, 0, pad_w)?
- } else {
- xs
- };
- let (h_p, w_p) = (h + pad_h, w + pad_w);
- let windows = xs
- .reshape((
- b,
- h_p / window_size,
- window_size,
- w_p / window_size,
- window_size,
- c,
- ))?
- .transpose(2, 3)?
- .contiguous()?
- .flatten_to(2)?;
- Ok((windows, (h_p, w_p)))
-}
-
-fn window_unpartition(
- windows: Tensor,
- window_size: usize,
- (h_p, w_p): (usize, usize),
- (h, w): (usize, usize),
-) -> Result<Tensor> {
- let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
- let xs = windows
- .reshape((
- b,
- h_p / window_size,
- w_p / window_size,
- window_size,
- window_size,
- windows.elem_count() / b / h_p / w_p,
- ))?
- .transpose(2, 3)?
- .contiguous()?
- .reshape((b, h_p, w_p, ()))?;
- let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
- let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
- Ok(xs)
-}
-
-impl Module for Block {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let shortcut = xs;
- let xs = self.norm1.forward(xs)?;
- let hw = (xs.dim(1)?, xs.dim(2)?);
- let (xs, pad_hw) = if self.window_size > 0 {
- window_partition(xs, self.window_size)?
- } else {
- (xs, (0, 0))
- };
- let xs = self.attn.forward(&xs)?;
- let xs = if self.window_size > 0 {
- window_unpartition(xs, self.window_size, pad_hw, hw)?
- } else {
- xs
- };
- let xs = (xs + shortcut)?;
- &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
- }
-}
-
-#[derive(Debug)]
-pub struct ImageEncoderViT {
- patch_embed: PatchEmbed,
- blocks: Vec<Block>,
- neck_conv1: candle_nn::Conv2d,
- neck_ln1: crate::LayerNorm2d,
- neck_conv2: candle_nn::Conv2d,
- neck_ln2: crate::LayerNorm2d,
- pos_embed: Option<Tensor>,
- span: tracing::Span,
-}
-
-impl ImageEncoderViT {
- #[allow(clippy::too_many_arguments)]
- pub fn new(
- img_size: usize,
- patch_size: usize,
- in_chans: usize,
- embed_dim: usize,
- depth: usize,
- num_heads: usize,
- out_chans: usize,
- qkv_bias: bool,
- use_rel_pos: bool,
- use_abs_pos: bool,
- window_size: usize,
- global_attn_indexes: &[usize],
- vb: VarBuilder,
- ) -> Result<Self> {
- let patch_embed = PatchEmbed::new(
- in_chans,
- embed_dim,
- patch_size,
- patch_size,
- 0,
- vb.pp("patch_embed"),
- )?;
- let mut blocks = Vec::with_capacity(depth);
- let vb_b = vb.pp("blocks");
- for i in 0..depth {
- let window_size = if global_attn_indexes.contains(&i) {
- 0
- } else {
- window_size
- };
- let block = Block::new(
- embed_dim,
- num_heads,
- qkv_bias,
- use_rel_pos,
- window_size,
- (img_size / patch_size, img_size / patch_size),
- vb_b.pp(i),
- )?;
- blocks.push(block)
- }
- let neck_conv1 = candle_nn::conv2d_no_bias(
- embed_dim,
- out_chans,
- 1,
- Default::default(),
- vb.pp("neck.0"),
- )?;
- let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
- let cfg = candle_nn::Conv2dConfig {
- padding: 1,
- ..Default::default()
- };
- let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
- let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
- let pos_embed = if use_abs_pos {
- let p = vb.get(
- (1, img_size / patch_size, img_size / patch_size, embed_dim),
- "pos_embed",
- )?;
- Some(p)
- } else {
- None
- };
- let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
- Ok(Self {
- patch_embed,
- blocks,
- neck_conv1,
- neck_ln1,
- neck_conv2,
- neck_ln2,
- pos_embed,
- span,
- })
- }
-}
-
-impl Module for ImageEncoderViT {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let xs = self.patch_embed.forward(xs)?;
- let mut xs = match &self.pos_embed {
- Some(pos_embed) => (xs + pos_embed)?,
- None => xs,
- };
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- xs.permute((0, 3, 1, 2))?
- .apply(&self.neck_conv1)?
- .apply(&self.neck_ln1)?
- .apply(&self.neck_conv2)?
- .apply(&self.neck_ln2)
- }
-}
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
deleted file mode 100644
index c02b44a7..00000000
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ /dev/null
@@ -1,239 +0,0 @@
-use candle::{IndexOp, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
-
-use crate::model_transformer::TwoWayTransformer;
-
-#[derive(Debug)]
-struct MlpMaskDecoder {
- layers: Vec<crate::Linear>,
- sigmoid_output: bool,
- span: tracing::Span,
-}
-
-impl MlpMaskDecoder {
- fn new(
- input_dim: usize,
- hidden_dim: usize,
- output_dim: usize,
- num_layers: usize,
- sigmoid_output: bool,
- vb: VarBuilder,
- ) -> Result<Self> {
- let mut layers = Vec::with_capacity(num_layers);
- let vb = vb.pp("layers");
- for i in 0..num_layers {
- let in_dim = if i == 0 { input_dim } else { hidden_dim };
- let out_dim = if i + 1 == num_layers {
- output_dim
- } else {
- hidden_dim
- };
- let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
- layers.push(layer)
- }
- let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
- Ok(Self {
- layers,
- sigmoid_output,
- span,
- })
- }
-}
-
-impl Module for MlpMaskDecoder {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut xs = xs.clone();
- for (i, layer) in self.layers.iter().enumerate() {
- xs = layer.forward(&xs)?;
- if i + 1 < self.layers.len() {
- xs = xs.relu()?
- }
- }
- if self.sigmoid_output {
- candle_nn::ops::sigmoid(&xs)
- } else {
- Ok(xs)
- }
- }
-}
-
-#[derive(Debug)]
-pub struct MaskDecoder {
- iou_token: candle_nn::Embedding,
- mask_tokens: candle_nn::Embedding,
- iou_prediction_head: MlpMaskDecoder,
- output_upscaling_conv1: candle_nn::ConvTranspose2d,
- output_upscaling_ln: crate::LayerNorm2d,
- output_upscaling_conv2: candle_nn::ConvTranspose2d,
- num_mask_tokens: usize,
- output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
- transformer: TwoWayTransformer,
- span: tracing::Span,
-}
-
-impl MaskDecoder {
- pub fn new(
- transformer_dim: usize,
- num_multimask_outputs: usize,
- iou_head_depth: usize,
- iou_head_hidden_dim: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let num_mask_tokens = num_multimask_outputs + 1;
- let iou_prediction_head = MlpMaskDecoder::new(
- transformer_dim,
- iou_head_hidden_dim,
- num_mask_tokens,
- iou_head_depth,
- false,
- vb.pp("iou_prediction_head"),
- )?;
- let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
- let mask_tokens =
- candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
- let cfg = candle_nn::ConvTranspose2dConfig {
- stride: 2,
- ..Default::default()
- };
- let output_upscaling_conv1 = candle_nn::conv_transpose2d(
- transformer_dim,
- transformer_dim / 4,
- 2,
- cfg,
- vb.pp("output_upscaling.0"),
- )?;
- let output_upscaling_ln =
- crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
- let output_upscaling_conv2 = candle_nn::conv_transpose2d(
- transformer_dim / 4,
- transformer_dim / 8,
- 2,
- cfg,
- vb.pp("output_upscaling.3"),
- )?;
- let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
- let vb_o = vb.pp("output_hypernetworks_mlps");
- for i in 0..num_mask_tokens {
- let mlp = MlpMaskDecoder::new(
- transformer_dim,
- transformer_dim,
- transformer_dim / 8,
- 3,
- false,
- vb_o.pp(i),
- )?;
- output_hypernetworks_mlps.push(mlp)
- }
- let transformer = TwoWayTransformer::new(
- /* depth */ 2,
- /* embedding_dim */ transformer_dim,
- /* num_heads */ 8,
- /* mlp_dim */ 2048,
- vb.pp("transformer"),
- )?;
- let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
- Ok(Self {
- iou_token,
- mask_tokens,
- iou_prediction_head,
- output_upscaling_conv1,
- output_upscaling_ln,
- output_upscaling_conv2,
- num_mask_tokens,
- output_hypernetworks_mlps,
- transformer,
- span,
- })
- }
-
- pub fn forward(
- &self,
- image_embeddings: &Tensor,
- image_pe: &Tensor,
- sparse_prompt_embeddings: &Tensor,
- dense_prompt_embeddings: &Tensor,
- multimask_output: bool,
- ) -> Result<(Tensor, Tensor)> {
- let _enter = self.span.enter();
- let (masks, iou_pred) = self.predict_masks(
- image_embeddings,
- image_pe,
- sparse_prompt_embeddings,
- dense_prompt_embeddings,
- )?;
- let masks = if multimask_output {
- masks.i((.., 1..))?
- } else {
- masks.i((.., 0..1))?
- };
- let iou_pred = if multimask_output {
- iou_pred.i((.., 1..))?
- } else {
- iou_pred.i((.., 0..1))?
- };
- Ok((masks, iou_pred))
- }
-
- fn predict_masks(
- &self,
- image_embeddings: &Tensor,
- image_pe: &Tensor,
- sparse_prompt_embeddings: &Tensor,
- dense_prompt_embeddings: &Tensor,
- ) -> Result<(Tensor, Tensor)> {
- // Concatenate ouput tokens.
- let output_tokens = Tensor::cat(
- &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
- 0,
- )?;
- let (d1, d2) = output_tokens.dims2()?;
- let output_tokens =
- output_tokens
- .unsqueeze(0)?
- .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
- let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
-
- // Expand per-image data in batch direction to be per mask
- let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
- let src = src.broadcast_add(dense_prompt_embeddings)?;
- let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
- let (b, c, h, w) = src.dims4()?;
-
- // Run the transformer
- let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
- let iou_token_out = hs.i((.., 0))?;
- let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
-
- // Upscale mask embeddings and predict masks using the masks tokens.
- let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
- let upscaled_embedding = self
- .output_upscaling_conv1
- .forward(&src)?
- .apply(&self.output_upscaling_ln)?
- .gelu()?
- .apply(&self.output_upscaling_conv2)?
- .gelu()?;
- let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
- for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
- let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
- hyper_in_list.push(h)
- }
- let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
- let (b, c, h, w) = upscaled_embedding.dims4()?;
- let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
- let masks = masks.reshape((b, (), h, w))?;
-
- // Generate mask quality predictions.
- let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
- Ok((masks, iou_pred))
- }
-}
-
-// Equivalent to torch.repeat_interleave
-fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
- let img = img.unsqueeze(dim + 1)?;
- let mut dims = img.dims().to_vec();
- dims[dim + 1] = repeats;
- img.broadcast_as(dims)?.flatten(dim, dim + 1)
-}
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
deleted file mode 100644
index 7bbe8419..00000000
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ /dev/null
@@ -1,239 +0,0 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::VarBuilder;
-
-#[derive(Debug)]
-struct PostionEmbeddingRandom {
- positional_encoding_gaussian_matrix: Tensor,
-}
-
-impl PostionEmbeddingRandom {
- fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
- let positional_encoding_gaussian_matrix =
- vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
- Ok(Self {
- positional_encoding_gaussian_matrix,
- })
- }
-
- fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
- let coords = coords.affine(2., -1.)?;
- let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
- let coords = (coords * (2. * std::f64::consts::PI))?;
- Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
- }
-
- fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
- let device = self.positional_encoding_gaussian_matrix.device();
- let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
- let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
- let x_embed = (x_embed / w as f64)?
- .reshape((1, ()))?
- .broadcast_as((h, w))?;
- let y_embed = (y_embed / h as f64)?
- .reshape(((), 1))?
- .broadcast_as((h, w))?;
- let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
- self.pe_encoding(&coords)?.permute((2, 0, 1))
- }
-
- fn forward_with_coords(
- &self,
- coords_input: &Tensor,
- image_size: (usize, usize),
- ) -> Result<Tensor> {
- let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
- let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
- let c = coords_input.dim(D::Minus1)?;
- let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
- let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
- self.pe_encoding(&coords)
- }
-}
-
-#[derive(Debug)]
-pub struct PromptEncoder {
- pe_layer: PostionEmbeddingRandom,
- point_embeddings: Vec<candle_nn::Embedding>,
- not_a_point_embed: candle_nn::Embedding,
- mask_downscaling_conv1: candle_nn::Conv2d,
- mask_downscaling_ln1: crate::LayerNorm2d,
- mask_downscaling_conv2: candle_nn::Conv2d,
- mask_downscaling_ln2: crate::LayerNorm2d,
- mask_downscaling_conv3: candle_nn::Conv2d,
- no_mask_embed: candle_nn::Embedding,
- image_embedding_size: (usize, usize),
- input_image_size: (usize, usize),
- embed_dim: usize,
- span: tracing::Span,
-}
-
-impl PromptEncoder {
- pub fn new(
- embed_dim: usize,
- image_embedding_size: (usize, usize),
- input_image_size: (usize, usize),
- mask_in_chans: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let num_points_embeddings = 4;
- let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
- let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
- let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
- let cfg = candle_nn::Conv2dConfig {
- stride: 2,
- ..Default::default()
- };
- let mask_downscaling_conv1 =
- candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
- let mask_downscaling_conv2 = candle_nn::conv2d(
- mask_in_chans / 4,
- mask_in_chans,
- 2,
- cfg,
- vb.pp("mask_downscaling.3"),
- )?;
- let mask_downscaling_conv3 = candle_nn::conv2d(
- mask_in_chans,
- embed_dim,
- 1,
- Default::default(),
- vb.pp("mask_downscaling.6"),
- )?;
- let mask_downscaling_ln1 =
- crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
- let mask_downscaling_ln2 =
- crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
- let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
- let vb_e = vb.pp("point_embeddings");
- for i in 0..num_points_embeddings {
- let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
- point_embeddings.push(emb)
- }
- let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
- Ok(Self {
- pe_layer,
- point_embeddings,
- not_a_point_embed,
- mask_downscaling_conv1,
- mask_downscaling_ln1,
- mask_downscaling_conv2,
- mask_downscaling_ln2,
- mask_downscaling_conv3,
- no_mask_embed,
- image_embedding_size,
- input_image_size,
- embed_dim,
- span,
- })
- }
-
- pub fn get_dense_pe(&self) -> Result<Tensor> {
- self.pe_layer
- .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
- .unsqueeze(0)
- }
-
- fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
- masks
- .apply(&self.mask_downscaling_conv1)?
- .apply(&self.mask_downscaling_ln1)?
- .gelu()?
- .apply(&self.mask_downscaling_conv2)?
- .apply(&self.mask_downscaling_ln2)?
- .gelu()?
- .apply(&self.mask_downscaling_conv3)
- }
-
- fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
- let points = (points + 0.5)?;
- let dev = points.device();
- let (points, labels) = if pad {
- let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
- let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
- let points = Tensor::cat(&[&points, &padding_point], 1)?;
- let labels = Tensor::cat(&[labels, &padding_label], 1)?;
- (points, labels)
- } else {
- (points, labels.clone())
- };
- let point_embedding = self
- .pe_layer
- .forward_with_coords(&points, self.input_image_size)?;
- let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
- let zeros = point_embedding.zeros_like()?;
- let point_embedding = labels.lt(0f32)?.where_cond(
- &self
- .not_a_point_embed
- .embeddings()
- .broadcast_as(zeros.shape())?,
- &point_embedding,
- )?;
- let labels0 = labels.eq(0f32)?.where_cond(
- &self.point_embeddings[0]
- .embeddings()
- .broadcast_as(zeros.shape())?,
- &zeros,
- )?;
- let point_embedding = (point_embedding + labels0)?;
- let labels1 = labels.eq(1f32)?.where_cond(
- &self.point_embeddings[1]
- .embeddings()
- .broadcast_as(zeros.shape())?,
- &zeros,
- )?;
- let point_embedding = (point_embedding + labels1)?;
- Ok(point_embedding)
- }
-
- fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
- let boxes = (boxes + 0.5)?;
- let coords = boxes.reshape(((), 2, 2))?;
- let corner_embedding = self
- .pe_layer
- .forward_with_coords(&coords, self.input_image_size)?;
- let ce1 = corner_embedding.i((.., 0))?;
- let ce2 = corner_embedding.i((.., 1))?;
- let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
- let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
- Tensor::cat(&[&ce1, &ce2], 1)
- }
-
- pub fn forward(
- &self,
- points: Option<(&Tensor, &Tensor)>,
- boxes: Option<&Tensor>,
- masks: Option<&Tensor>,
- ) -> Result<(Tensor, Tensor)> {
- let _enter = self.span.enter();
- let se_points = match points {
- Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
- None => None,
- };
- let se_boxes = match boxes {
- Some(boxes) => Some(self.embed_boxes(boxes)?),
- None => None,
- };
- let sparse_embeddings = match (se_points, se_boxes) {
- (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
- (Some(se_points), None) => se_points,
- (None, Some(se_boxes)) => se_boxes,
- (None, None) => {
- Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
- }
- };
-
- let dense_embeddings = match masks {
- None => {
- let emb = self.no_mask_embed.embeddings();
- emb.reshape((1, (), 1, 1))?.expand((
- 1,
- emb.elem_count(),
- self.image_embedding_size.0,
- self.image_embedding_size.1,
- ))?
- }
- Some(masks) => self.embed_masks(masks)?,
- };
- Ok((sparse_embeddings, dense_embeddings))
- }
-}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
deleted file mode 100644
index b1a81af6..00000000
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ /dev/null
@@ -1,411 +0,0 @@
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
-
-use crate::model_image_encoder::ImageEncoderViT;
-use crate::model_mask_decoder::MaskDecoder;
-use crate::model_prompt_encoder::PromptEncoder;
-use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
-
-const PROMPT_EMBED_DIM: usize = 256;
-pub const IMAGE_SIZE: usize = 1024;
-const VIT_PATCH_SIZE: usize = 16;
-const PRED_IOU_THRESH: f32 = 0.88;
-const STABILITY_SCORE_OFFSET: f32 = 1.0;
-const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
-const MODEL_MASK_THRESHOLD: f32 = 0.0;
-const CROP_NMS_THRESH: f32 = 0.7;
-
-#[derive(Debug)]
-enum ImageEncoder {
- Original(ImageEncoderViT),
- TinyViT(TinyViT),
-}
-
-impl Module for ImageEncoder {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- match self {
- Self::Original(vit) => vit.forward(xs),
- Self::TinyViT(vit) => vit.forward(xs),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct Sam {
- image_encoder: ImageEncoder,
- prompt_encoder: PromptEncoder,
- mask_decoder: MaskDecoder,
- pixel_mean: Tensor,
- pixel_std: Tensor,
-}
-
-impl Sam {
- pub fn new(
- encoder_embed_dim: usize,
- encoder_depth: usize,
- encoder_num_heads: usize,
- encoder_global_attn_indexes: &[usize],
- vb: VarBuilder,
- ) -> Result<Self> {
- let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
-
- let image_encoder = ImageEncoderViT::new(
- IMAGE_SIZE,
- VIT_PATCH_SIZE,
- 3,
- encoder_embed_dim,
- encoder_depth,
- encoder_num_heads,
- PROMPT_EMBED_DIM,
- /* qkv_bias */ true,
- /* use_rel_pos */ true,
- /* use_abs_pos */ true,
- /* window_size */ 14,
- /* global_attn_indexes */ encoder_global_attn_indexes,
- vb.pp("image_encoder"),
- )?;
- let prompt_encoder = PromptEncoder::new(
- PROMPT_EMBED_DIM,
- (image_embedding_size, image_embedding_size),
- (IMAGE_SIZE, IMAGE_SIZE),
- 16,
- vb.pp("prompt_encoder"),
- )?;
- let mask_decoder = MaskDecoder::new(
- PROMPT_EMBED_DIM,
- /* num_multitask_outputs */ 3,
- /* iou_head_depth */ 3,
- /* iou_head_hidden_dim */ 256,
- vb.pp("mask_decoder"),
- )?;
- let pixel_mean =
- Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
- let pixel_std =
- Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
- Ok(Self {
- image_encoder: ImageEncoder::Original(image_encoder),
- prompt_encoder,
- mask_decoder,
- pixel_std,
- pixel_mean,
- })
- }
-
- pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
- let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
-
- let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
- let prompt_encoder = PromptEncoder::new(
- PROMPT_EMBED_DIM,
- (image_embedding_size, image_embedding_size),
- (IMAGE_SIZE, IMAGE_SIZE),
- 16,
- vb.pp("prompt_encoder"),
- )?;
- let mask_decoder = MaskDecoder::new(
- PROMPT_EMBED_DIM,
- /* num_multitask_outputs */ 3,
- /* iou_head_depth */ 3,
- /* iou_head_hidden_dim */ 256,
- vb.pp("mask_decoder"),
- )?;
- let pixel_mean =
- Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
- let pixel_std =
- Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
- Ok(Self {
- image_encoder: ImageEncoder::TinyViT(image_encoder),
- prompt_encoder,
- mask_decoder,
- pixel_std,
- pixel_mean,
- })
- }
-
- pub fn forward(
- &self,
- img: &Tensor,
- point: Option<(f64, f64)>,
- multimask_output: bool,
- ) -> Result<(Tensor, Tensor)> {
- let (_c, original_h, original_w) = img.dims3()?;
- let img = self.preprocess(img)?.unsqueeze(0)?;
- let img_embeddings = self.image_encoder.forward(&img)?;
- let image_pe = self.prompt_encoder.get_dense_pe()?;
- let points = match point {
- None => None,
- Some((x, y)) => {
- let points = Tensor::new(
- &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
- img.device(),
- )?;
- let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
- Some((points, labels))
- }
- };
- let points = points.as_ref().map(|(x, y)| (x, y));
- let (sparse_prompt_embeddings, dense_prompt_embeddings) =
- self.prompt_encoder.forward(points, None, None)?;
- let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
- &img_embeddings,
- &image_pe,
- &sparse_prompt_embeddings,
- &dense_prompt_embeddings,
- multimask_output,
- )?;
- let mask = low_res_mask
- .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
- .get(0)?
- .i((.., ..original_h, ..original_w))?;
- Ok((mask, iou_predictions))
- }
-
- pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
- let img = img
- .broadcast_mul(&self.pixel_std)?
- .broadcast_add(&self.pixel_mean)?;
- img.maximum(&img.zeros_like()?)?
- .minimum(&(img.ones_like()? * 255.)?)
- }
-
- pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
- let (_c, h, w) = img.dims3()?;
- let img = img
- .to_dtype(DType::F32)?
- .broadcast_sub(&self.pixel_mean)?
- .broadcast_div(&self.pixel_std)?;
- if h > IMAGE_SIZE || w > IMAGE_SIZE {
- candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
- }
- let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
- img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
- }
-
- fn process_crop(
- &self,
- img: &Tensor,
- cb: CropBox,
- point_grids: &[(f64, f64)],
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
- // Crop the image and calculate embeddings.
- let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
- let img = self.preprocess(&img)?.unsqueeze(0)?;
- let img_embeddings = self.image_encoder.forward(&img)?;
-
- let crop_w = cb.x1 - cb.x0;
- let crop_h = cb.y1 - cb.y0;
-
- // Generate masks for this crop.
- let image_pe = self.prompt_encoder.get_dense_pe()?;
- let points = point_grids
- .iter()
- .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
- .collect::<Vec<_>>();
-
- let mut bboxes = Vec::new();
- for points in points.chunks(64) {
- // Run the model on this batch.
- let points_len = points.len();
- let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
- let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
- let (sparse_prompt_embeddings, dense_prompt_embeddings) =
- self.prompt_encoder
- .forward(Some((&in_points, &in_labels)), None, None)?;
-
- let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
- &img_embeddings,
- &image_pe,
- &sparse_prompt_embeddings,
- &dense_prompt_embeddings,
- /* multimask_output */ true,
- )?;
- let low_res_mask = low_res_mask.flatten(0, 1)?;
- let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
- let dev = low_res_mask.device();
-
- for (i, iou) in iou_predictions.iter().enumerate() {
- // Filter by predicted IoU.
- if *iou < PRED_IOU_THRESH {
- continue;
- }
- let low_res_mask = low_res_mask.get(i)?;
-
- // Calculate stability score.
- let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
- .broadcast_as(low_res_mask.shape())?;
- let intersections = low_res_mask
- .ge(&bound)?
- .to_dtype(DType::F32)?
- .sum_all()?
- .to_vec0::<f32>()?;
- let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
- .broadcast_as(low_res_mask.shape())?;
- let unions = low_res_mask
- .ge(&bound)?
- .to_dtype(DType::F32)?
- .sum_all()?
- .to_vec0::<f32>()?;
- let stability_score = intersections / unions;
- if stability_score < STABILITY_SCORE_THRESHOLD {
- continue;
- }
-
- // Threshold masks and calculate boxes.
- let low_res_mask = low_res_mask
- .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
- .to_dtype(DType::U32)?;
- let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
- let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
- let min_max_x = min_max_indexes(&low_res_mask_per_x);
- let min_max_y = min_max_indexes(&low_res_mask_per_y);
- if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
- let bbox = candle_examples::object_detection::Bbox {
- xmin: x0 as f32,
- ymin: y0 as f32,
- xmax: x1 as f32,
- ymax: y1 as f32,
- confidence: *iou,
- data: low_res_mask,
- };
- bboxes.push(bbox);
- }
- // TODO:
- // Filter boxes that touch crop boundaries
- // Compress to RLE.
- }
- }
-
- let mut bboxes = vec![bboxes];
- // Remove duplicates within this crop.
- candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
-
- // TODO: Return to the original image frame.
- Ok(bboxes.remove(0))
- }
-
- pub fn generate_masks(
- &self,
- img: &Tensor,
- points_per_side: usize,
- crop_n_layer: usize,
- crop_overlap_ratio: f64,
- crop_n_points_downscale_factor: usize,
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
- let (_c, h, w) = img.dims3()?;
- let point_grids = build_all_layer_point_grids(
- points_per_side,
- crop_n_layer,
- crop_n_points_downscale_factor,
- );
- let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
- let mut bboxes = Vec::new();
- for crop_box in crop_boxes.into_iter() {
- let layer_idx = crop_box.layer_idx;
- let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
- bboxes.extend(b)
- }
- // TODO: remove duplicates
- Ok(bboxes)
- }
-}
-
-// Return the first and last indexes i for which values[i] > 0
-fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
- let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
- for (i, &s) in values.iter().enumerate() {
- if s == 0 {
- continue;
- }
- min_i = usize::min(i, min_i);
- max_i = usize::max(i, max_i);
- }
- if max_i < min_i {
- None
- } else {
- Some((min_i, max_i))
- }
-}
-
-#[derive(Debug)]
-struct CropBox {
- x0: usize,
- y0: usize,
- x1: usize,
- y1: usize,
- layer_idx: usize,
-}
-
-impl CropBox {
- fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
- Self {
- x0,
- y0,
- x1,
- y1,
- layer_idx,
- }
- }
-}
-
-fn generate_crop_boxes(
- (im_h, im_w): (usize, usize),
- n_layers: usize,
- overlap_ratio: f64,
-) -> Vec<CropBox> {
- fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
- f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
- }
-
- let short_side = usize::min(im_h, im_w);
-
- let mut crop_boxes = Vec::new();
-
- // Original image.
- crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
-
- for layer_idx in 1..=n_layers {
- let n_crops_per_side = 1 << layer_idx;
- let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
- let crop_w = crop_len(im_w, n_crops_per_side, overlap);
- let crop_h = crop_len(im_w, n_crops_per_side, overlap);
-
- for i_x in 0..n_crops_per_side {
- let x0 = (crop_w - overlap) * i_x;
- for i_y in 0..n_crops_per_side {
- let y0 = (crop_h - overlap) * i_y;
- let x1 = usize::min(im_w, x0 + crop_w);
- let y1 = usize::min(im_h, y0 + crop_h);
- crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
- }
- }
- }
-
- crop_boxes
-}
-
-// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
-fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
- let offset = 1f64 / (2 * n_per_side) as f64;
- let mut points = Vec::with_capacity(n_per_side * n_per_side);
- for i_x in 0..n_per_side {
- let x = offset + i_x as f64 / n_per_side as f64;
- for i_y in 0..n_per_side {
- let y = offset + i_y as f64 / n_per_side as f64;
- points.push((x, y))
- }
- }
- points
-}
-
-fn build_all_layer_point_grids(
- n_per_side: usize,
- n_layers: usize,
- scale_per_layer: usize,
-) -> Vec<Vec<(f64, f64)>> {
- let mut points_by_layer = Vec::with_capacity(n_layers + 1);
- for i in 0..=n_layers {
- let n_points = n_per_side / scale_per_layer.pow(i as u32);
- points_by_layer.push(build_point_grid(n_points))
- }
- points_by_layer
-}
diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-examples/examples/segment-anything/model_tiny_vit.rs
deleted file mode 100644
index ff076773..00000000
--- a/candle-examples/examples/segment-anything/model_tiny_vit.rs
+++ /dev/null
@@ -1,633 +0,0 @@
-// Adapted from:
-// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{Conv2dConfig, Module, VarBuilder};
-
-const MBCONV_EXPAND_RATIO: usize = 4;
-const MLP_RATIO: usize = 4;
-const LOCAL_CONV_SIZE: usize = 3;
-const IMG_SIZE: usize = 1024;
-const IN_CHANNELS: usize = 3;
-
-#[derive(Debug)]
-struct Conv2dBN {
- c: candle_nn::Conv2d,
- bn: candle_nn::BatchNorm,
- span: tracing::Span,
-}
-
-impl Conv2dBN {
- fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
- let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
- let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
- let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
- Ok(Self { c, bn, span })
- }
-}
-
-impl Module for Conv2dBN {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.c)?.apply(&self.bn)
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- span: tracing::Span,
-}
-
-impl PatchEmbed {
- fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
- let cfg = candle_nn::Conv2dConfig {
- stride: 2,
- padding: 1,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
- let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
- let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
- Ok(Self { conv1, conv2, span })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- conv3: Conv2dBN,
- span: tracing::Span,
-}
-
-impl MBConv {
- fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
- let hidden = in_ * expand_ratio;
- let cfg2 = candle_nn::Conv2dConfig {
- padding: 1,
- groups: hidden,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
- let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
- let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
- let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
- Ok(Self {
- conv1,
- conv2,
- conv3,
- span,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let shortcut = xs;
- let xs = xs
- .apply(&self.conv1)?
- .gelu()?
- .apply(&self.conv2)?
- .gelu()?
- .apply(&self.conv3)?;
- (xs + shortcut)?.gelu()
- }
-}
-
-#[derive(Debug)]
-struct PatchMerging {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- conv3: Conv2dBN,
- input_resolution: (usize, usize),
- span: tracing::Span,
-}
-
-impl PatchMerging {
- fn new(
- input_resolution: (usize, usize),
- dim: usize,
- out: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
- let cfg2 = candle_nn::Conv2dConfig {
- padding: 1,
- stride,
- groups: out,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
- let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
- let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
- let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
- Ok(Self {
- conv1,
- conv2,
- conv3,
- input_resolution,
- span,
- })
- }
-}
-
-impl Module for PatchMerging {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let xs = if xs.rank() == 3 {
- let (h, w) = self.input_resolution;
- let b = xs.dim(0)?;
- xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
- } else {
- xs.clone()
- };
- xs.apply(&self.conv1)?
- .gelu()?
- .apply(&self.conv2)?
- .gelu()?
- .apply(&self.conv3)?
- .flatten_from(2)?
- .transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-struct ConvLayer {
- blocks: Vec<MBConv>,
- downsample: Option<PatchMerging>,
- span: tracing::Span,
-}
-
-impl ConvLayer {
- fn new(
- dim: usize,
- out: usize,
- input_resolution: (usize, usize),
- depth: usize,
- downsample: bool,
- conv_expand_ratio: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_b = vb.pp("blocks");
- let mut blocks = Vec::with_capacity(depth);
- for index in 0..depth {
- let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
- blocks.push(block)
- }
- let downsample = if downsample {
- let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
- Some(downsample)
- } else {
- None
- };
- let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
- Ok(Self {
- blocks,
- downsample,
- span,
- })
- }
-}
-
-impl Module for ConvLayer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut xs = xs.clone();
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- match &self.downsample {
- None => Ok(xs),
- Some(downsample) => downsample.forward(&xs),
- }
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- norm: candle_nn::LayerNorm,
- fc1: crate::Linear,
- fc2: crate::Linear,
- span: tracing::Span,
-}
-
-impl Mlp {
- fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
- let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
- let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
- let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
- let span = tracing::span!(tracing::Level::TRACE, "mlp");
- Ok(Self {
- norm,
- fc1,
- fc2,
- span,
- })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.norm)?
- .apply(&self.fc1)?
- .gelu()?
- .apply(&self.fc2)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- norm: candle_nn::LayerNorm,
- qkv: crate::Linear,
- proj: crate::Linear,
- ab: Tensor,
- key_dim: usize,
- num_heads: usize,
- d: usize,
- dh: usize,
- scale: f64,
- span: tracing::Span,
- span_matmul: tracing::Span,
- span_softmax: tracing::Span,
-}
-
-impl Attention {
- fn new(
- dim: usize,
- key_dim: usize,
- num_heads: usize,
- attn_ratio: usize,
- resolution: (usize, usize),
- vb: VarBuilder,
- ) -> Result<Self> {
- let d = attn_ratio * key_dim;
- let dh = d * num_heads;
- let nh_kd = key_dim * num_heads;
- let h = dh + nh_kd * 2;
- let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
- let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
- let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
-
- let points = (0..resolution.0)
- .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
- .collect::<Vec<_>>();
- let mut idxs = Vec::with_capacity(points.len() * points.len());
- let mut attention_offsets = std::collections::HashMap::new();
- for &(x1, y1) in points.iter() {
- for &(x2, y2) in points.iter() {
- let offset = ((x2 - x1).abs(), (y2 - y1).abs());
- let l = attention_offsets.len();
- let idx = attention_offsets.entry(offset).or_insert(l);
- idxs.push(*idx as u32)
- }
- }
- let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
- let idxs = Tensor::new(idxs, attention_biases.device())?;
- let ab =
- attention_biases
- .index_select(&idxs, 1)?
- .reshape(((), points.len(), points.len()))?;
- let span = tracing::span!(tracing::Level::TRACE, "attention");
- let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
- let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
- Ok(Self {
- norm,
- qkv,
- proj,
- ab,
- key_dim,
- num_heads,
- d,
- dh,
- scale: 1f64 / (key_dim as f64).sqrt(),
- span,
- span_matmul,
- span_softmax,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (b, n, _) = xs.dims3()?;
- let xs = xs.apply(&self.norm)?;
- let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
- let q = qkv
- .narrow(D::Minus1, 0, self.key_dim)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let k = qkv
- .narrow(D::Minus1, self.key_dim, self.key_dim)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let v = qkv
- .narrow(D::Minus1, 2 * self.key_dim, self.d)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let attn = {
- let _enter = self.span_matmul.enter();
- (q.matmul(&k.t()?)? * self.scale)?
- };
- let attn = attn.broadcast_add(&self.ab)?;
- let attn = {
- let _enter = self.span_softmax.enter();
- candle_nn::ops::softmax_last_dim(&attn)?
- };
- let attn = {
- let _enter = self.span_matmul.enter();
- attn.matmul(&v)?
- };
- attn.transpose(1, 2)?
- .reshape((b, n, self.dh))?
- .apply(&self.proj)
- }
-}
-
-#[derive(Debug)]
-struct TinyViTBlock {
- attn: Attention,
- local_conv: Conv2dBN,
- mlp: Mlp,
- window_size: usize,
- input_resolution: (usize, usize),
- span: tracing::Span,
-}
-
-impl TinyViTBlock {
- fn new(
- dim: usize,
- input_resolution: (usize, usize),
- num_heads: usize,
- window_size: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let head_dim = dim / num_heads;
- let attn = Attention::new(
- dim,
- head_dim,
- num_heads,
- 1,
- (window_size, window_size),
- vb.pp("attn"),
- )?;
- let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
- let cfg = candle_nn::Conv2dConfig {
- padding: LOCAL_CONV_SIZE / 2,
- groups: dim,
- ..Default::default()
- };
- let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
- let span = tracing::span!(tracing::Level::TRACE, "attention");
- Ok(Self {
- attn,
- local_conv,
- mlp,
- window_size,
- input_resolution,
- span,
- })
- }
-}
-
-impl Module for TinyViTBlock {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (h, w) = self.input_resolution;
- let (b, l, c) = xs.dims3()?;
- let res_x = xs;
- let xs = if h == self.window_size && w == self.window_size {
- self.attn.forward(xs)?
- } else {
- let xs = xs.reshape((b, h, w, c))?;
- let pad_b = (self.window_size - h % self.window_size) % self.window_size;
- let pad_r = (self.window_size - w % self.window_size) % self.window_size;
-
- let xs = if pad_b > 0 {
- xs.pad_with_zeros(1, 0, pad_b)?
- } else {
- xs
- };
- let xs = if pad_r > 0 {
- xs.pad_with_zeros(2, 0, pad_r)?
- } else {
- xs
- };
- let (p_h, p_w) = (h + pad_b, w + pad_r);
- let n_h = p_h / self.window_size;
- let n_w = p_w / self.window_size;
- let xs = xs
- .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
- .transpose(2, 3)?
- .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
- let xs = self.attn.forward(&xs)?;
- let xs = xs
- .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
- .transpose(2, 3)?
- .reshape((b, p_h, p_w, c))?;
- let xs = if pad_r > 0 {
- xs.i((.., .., ..w))?.contiguous()?
- } else {
- xs
- };
- let xs = if pad_b > 0 {
- xs.i((.., ..h, ..))?.contiguous()?
- } else {
- xs
- };
- xs.reshape((b, l, c))?
- };
- let xs = (xs + res_x)?;
- let xs = xs
- .transpose(1, 2)?
- .reshape((b, c, h, w))?
- .apply(&self.local_conv)?
- .reshape((b, c, l))?
- .transpose(1, 2)?;
- &xs + self.mlp.forward(&xs)?
- }
-}
-
-#[derive(Debug)]
-struct BasicLayer {
- blocks: Vec<TinyViTBlock>,
- downsample: Option<PatchMerging>,
- span: tracing::Span,
-}
-
-impl BasicLayer {
- #[allow(clippy::too_many_arguments)]
- fn new(
- dim: usize,
- input_resolution: (usize, usize),
- depth: usize,
- num_heads: usize,
- window_size: usize,
- downsample: bool,
- out: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_b = vb.pp("blocks");
- let mut blocks = Vec::with_capacity(depth);
- for index in 0..depth {
- let block = TinyViTBlock::new(
- dim,
- input_resolution,
- num_heads,
- window_size,
- vb_b.pp(index),
- )?;
- blocks.push(block)
- }
- let downsample = if downsample {
- let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
- Some(downsample)
- } else {
- None
- };
- let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
- Ok(Self {
- blocks,
- downsample,
- span,
- })
- }
-}
-
-impl Module for BasicLayer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut xs = xs.clone();
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- match &self.downsample {
- None => Ok(xs),
- Some(downsample) => downsample.forward(&xs),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct TinyViT {
- patch_embed: PatchEmbed,
- layer0: ConvLayer,
- layers: Vec<BasicLayer>,
- // norm_head: candle_nn::LayerNorm,
- // head: candle_nn::Linear,
- neck_conv1: candle_nn::Conv2d,
- neck_ln1: crate::LayerNorm2d,
- neck_conv2: candle_nn::Conv2d,
- neck_ln2: crate::LayerNorm2d,
- span: tracing::Span,
- span_neck: tracing::Span,
-}
-
-impl TinyViT {
- pub fn new(
- embed_dims: &[usize],
- depths: &[usize],
- num_heads: &[usize],
- window_sizes: &[usize],
- _num_classes: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
- let patches_resolution = IMG_SIZE / 4;
-
- let vb_l = vb.pp("layers");
- let layer0 = ConvLayer::new(
- /* dim */ embed_dims[0],
- /* out */ embed_dims[1],
- /* input_resolution */ (patches_resolution, patches_resolution),
- /* depth */ depths[0],
- /* downsample */ true,
- /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
- vb_l.pp(0),
- )?;
-
- let num_layers = embed_dims.len();
- let mut layers = Vec::with_capacity(num_layers - 1);
- for i_layer in 1..num_layers {
- let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
- let layer = BasicLayer::new(
- /* dim */ embed_dims[i_layer],
- /* input_resolution */ (patches_resolution, patches_resolution),
- /* depth */ depths[i_layer],
- /* num_heads */ num_heads[i_layer],
- /* window_size */ window_sizes[i_layer],
- /* downsample */ i_layer < num_layers - 1,
- /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
- vb_l.pp(i_layer),
- )?;
- layers.push(layer)
- }
-
- let last_embed_dim = embed_dims[embed_dims.len() - 1];
- // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
- // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
- let neck_conv1 =
- candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
- let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
- let cfg = candle_nn::Conv2dConfig {
- padding: 1,
- ..Default::default()
- };
- let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
- let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
-
- let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
- let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
- Ok(Self {
- patch_embed,
- layer0,
- layers,
- neck_conv1,
- neck_ln1,
- neck_conv2,
- neck_ln2,
- span,
- span_neck,
- })
- }
-}
-
-impl Module for TinyViT {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let xs = self.patch_embed.forward(xs)?;
- let mut xs = self.layer0.forward(&xs)?;
- for layer in self.layers.iter() {
- xs = layer.forward(&xs)?
- }
- let (b, _, c) = xs.dims3()?;
- let _enter = self.span_neck.enter();
- xs.reshape((b, 64, 64, c))?
- .permute((0, 3, 1, 2))?
- .apply(&self.neck_conv1)?
- .apply(&self.neck_ln1)?
- .apply(&self.neck_conv2)?
- .apply(&self.neck_ln2)
- }
-}
-
-pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
- TinyViT::new(
- /* embed_dims */ &[64, 128, 160, 320],
- /* depths */ &[2, 2, 6, 2],
- /* num_heads */ &[2, 4, 5, 10],
- /* window_sizes */ &[7, 7, 14, 7],
- /* num_classes */ 1000,
- vb,
- )
-}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
deleted file mode 100644
index e12aac08..00000000
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ /dev/null
@@ -1,221 +0,0 @@
-use candle::{Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
-
-#[derive(Debug)]
-struct Attention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
- out_proj: Linear,
- num_heads: usize,
-}
-
-impl Attention {
- fn new(
- embedding_dim: usize,
- num_heads: usize,
- downsample_rate: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let internal_dim = embedding_dim / downsample_rate;
- let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
- let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
- let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
- let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
- Ok(Self {
- q_proj,
- k_proj,
- v_proj,
- out_proj,
- num_heads,
- })
- }
-
- fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
- let (b, n, c) = x.dims3()?;
- x.reshape((b, n, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)?
- .contiguous()
- }
-
- fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
- let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
- x.transpose(1, 2)?
- .reshape((b, n_tokens, n_heads * c_per_head))
- }
-
- fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
- let q = self.q_proj.forward(&q.contiguous()?)?;
- let k = self.k_proj.forward(&k.contiguous()?)?;
- let v = self.v_proj.forward(&v.contiguous()?)?;
-
- let q = self.separate_heads(&q)?;
- let k = self.separate_heads(&k)?;
- let v = self.separate_heads(&v)?;
-
- let (_, _, _, c_per_head) = q.dims4()?;
- let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
- let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-
- let out = attn.matmul(&v)?;
- self.recombine_heads(&out)?.apply(&self.out_proj)
- }
-}
-
-#[derive(Debug)]
-struct TwoWayAttentionBlock {
- self_attn: Attention,
- norm1: LayerNorm,
- cross_attn_token_to_image: Attention,
- norm2: LayerNorm,
- mlp: crate::MlpBlock,
- norm3: LayerNorm,
- norm4: LayerNorm,
- cross_attn_image_to_token: Attention,
- skip_first_layer_pe: bool,
-}
-
-impl TwoWayAttentionBlock {
- fn new(
- embedding_dim: usize,
- num_heads: usize,
- mlp_dim: usize,
- skip_first_layer_pe: bool,
- vb: VarBuilder,
- ) -> Result<Self> {
- let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
- let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
- let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
- let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
- let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
- let cross_attn_token_to_image = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("cross_attn_token_to_image"),
- )?;
- let cross_attn_image_to_token = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("cross_attn_image_to_token"),
- )?;
- let mlp = crate::MlpBlock::new(
- embedding_dim,
- mlp_dim,
- candle_nn::Activation::Relu,
- vb.pp("mlp"),
- )?;
- Ok(Self {
- self_attn,
- norm1,
- cross_attn_image_to_token,
- norm2,
- mlp,
- norm3,
- norm4,
- cross_attn_token_to_image,
- skip_first_layer_pe,
- })
- }
-
- fn forward(
- &self,
- queries: &Tensor,
- keys: &Tensor,
- query_pe: &Tensor,
- key_pe: &Tensor,
- ) -> Result<(Tensor, Tensor)> {
- // Self attention block
- let queries = if self.skip_first_layer_pe {
- self.self_attn.forward(queries, queries, queries)?
- } else {
- let q = (queries + query_pe)?;
- let attn_out = self.self_attn.forward(&q, &q, queries)?;
- (queries + attn_out)?
- };
- let queries = self.norm1.forward(&queries)?;
-
- // Cross attention block, tokens attending to image embedding
- let q = (&queries + query_pe)?;
- let k = (keys + key_pe)?;
- let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
- let queries = (&queries + attn_out)?;
- let queries = self.norm2.forward(&queries)?;
-
- // MLP block
- let mlp_out = self.mlp.forward(&queries);
- let queries = (queries + mlp_out)?;
- let queries = self.norm3.forward(&queries)?;
-
- // Cross attention block, image embedding attending to tokens
- let q = (&queries + query_pe)?;
- let k = (keys + key_pe)?;
- let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
- let keys = (keys + attn_out)?;
- let keys = self.norm4.forward(&keys)?;
-
- Ok((queries, keys))
- }
-}
-
-#[derive(Debug)]
-pub struct TwoWayTransformer {
- layers: Vec<TwoWayAttentionBlock>,
- final_attn_token_to_image: Attention,
- norm_final_attn: LayerNorm,
-}
-
-impl TwoWayTransformer {
- pub fn new(
- depth: usize,
- embedding_dim: usize,
- num_heads: usize,
- mlp_dim: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_l = vb.pp("layers");
- let mut layers = Vec::with_capacity(depth);
- for i in 0..depth {
- let layer =
- TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
- layers.push(layer)
- }
- let final_attn_token_to_image = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("final_attn_token_to_image"),
- )?;
- let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
- Ok(Self {
- layers,
- final_attn_token_to_image,
- norm_final_attn,
- })
- }
-
- pub fn forward(
- &self,
- image_embedding: &Tensor,
- image_pe: &Tensor,
- point_embedding: &Tensor,
- ) -> Result<(Tensor, Tensor)> {
- let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
- let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
-
- let mut queries = point_embedding.clone();
- let mut keys = image_embedding;
-
- for layer in self.layers.iter() {
- (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
- }
-
- let q = (&queries + point_embedding)?;
- let k = (&keys + image_pe)?;
- let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
- let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
-
- Ok((queries, keys))
- }
-}
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 20021b45..ecf75bdf 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_examples::object_detection::{non_maximum_suppression, Bbox};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index 2017b5be..d48bac35 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
-use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;