summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_tiny_vit.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_tiny_vit.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_tiny_vit.rs633
1 files changed, 0 insertions, 633 deletions
diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-examples/examples/segment-anything/model_tiny_vit.rs
deleted file mode 100644
index ff076773..00000000
--- a/candle-examples/examples/segment-anything/model_tiny_vit.rs
+++ /dev/null
@@ -1,633 +0,0 @@
-// Adapted from:
-// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{Conv2dConfig, Module, VarBuilder};
-
-const MBCONV_EXPAND_RATIO: usize = 4;
-const MLP_RATIO: usize = 4;
-const LOCAL_CONV_SIZE: usize = 3;
-const IMG_SIZE: usize = 1024;
-const IN_CHANNELS: usize = 3;
-
-#[derive(Debug)]
-struct Conv2dBN {
- c: candle_nn::Conv2d,
- bn: candle_nn::BatchNorm,
- span: tracing::Span,
-}
-
-impl Conv2dBN {
- fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
- let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
- let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
- let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
- Ok(Self { c, bn, span })
- }
-}
-
-impl Module for Conv2dBN {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.c)?.apply(&self.bn)
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- span: tracing::Span,
-}
-
-impl PatchEmbed {
- fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
- let cfg = candle_nn::Conv2dConfig {
- stride: 2,
- padding: 1,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
- let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
- let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
- Ok(Self { conv1, conv2, span })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- conv3: Conv2dBN,
- span: tracing::Span,
-}
-
-impl MBConv {
- fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
- let hidden = in_ * expand_ratio;
- let cfg2 = candle_nn::Conv2dConfig {
- padding: 1,
- groups: hidden,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
- let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
- let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
- let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
- Ok(Self {
- conv1,
- conv2,
- conv3,
- span,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let shortcut = xs;
- let xs = xs
- .apply(&self.conv1)?
- .gelu()?
- .apply(&self.conv2)?
- .gelu()?
- .apply(&self.conv3)?;
- (xs + shortcut)?.gelu()
- }
-}
-
-#[derive(Debug)]
-struct PatchMerging {
- conv1: Conv2dBN,
- conv2: Conv2dBN,
- conv3: Conv2dBN,
- input_resolution: (usize, usize),
- span: tracing::Span,
-}
-
-impl PatchMerging {
- fn new(
- input_resolution: (usize, usize),
- dim: usize,
- out: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
- let cfg2 = candle_nn::Conv2dConfig {
- padding: 1,
- stride,
- groups: out,
- ..Default::default()
- };
- let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
- let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
- let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
- let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
- Ok(Self {
- conv1,
- conv2,
- conv3,
- input_resolution,
- span,
- })
- }
-}
-
-impl Module for PatchMerging {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let xs = if xs.rank() == 3 {
- let (h, w) = self.input_resolution;
- let b = xs.dim(0)?;
- xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
- } else {
- xs.clone()
- };
- xs.apply(&self.conv1)?
- .gelu()?
- .apply(&self.conv2)?
- .gelu()?
- .apply(&self.conv3)?
- .flatten_from(2)?
- .transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-struct ConvLayer {
- blocks: Vec<MBConv>,
- downsample: Option<PatchMerging>,
- span: tracing::Span,
-}
-
-impl ConvLayer {
- fn new(
- dim: usize,
- out: usize,
- input_resolution: (usize, usize),
- depth: usize,
- downsample: bool,
- conv_expand_ratio: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_b = vb.pp("blocks");
- let mut blocks = Vec::with_capacity(depth);
- for index in 0..depth {
- let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
- blocks.push(block)
- }
- let downsample = if downsample {
- let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
- Some(downsample)
- } else {
- None
- };
- let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
- Ok(Self {
- blocks,
- downsample,
- span,
- })
- }
-}
-
-impl Module for ConvLayer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut xs = xs.clone();
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- match &self.downsample {
- None => Ok(xs),
- Some(downsample) => downsample.forward(&xs),
- }
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- norm: candle_nn::LayerNorm,
- fc1: crate::Linear,
- fc2: crate::Linear,
- span: tracing::Span,
-}
-
-impl Mlp {
- fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
- let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
- let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
- let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
- let span = tracing::span!(tracing::Level::TRACE, "mlp");
- Ok(Self {
- norm,
- fc1,
- fc2,
- span,
- })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.norm)?
- .apply(&self.fc1)?
- .gelu()?
- .apply(&self.fc2)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- norm: candle_nn::LayerNorm,
- qkv: crate::Linear,
- proj: crate::Linear,
- ab: Tensor,
- key_dim: usize,
- num_heads: usize,
- d: usize,
- dh: usize,
- scale: f64,
- span: tracing::Span,
- span_matmul: tracing::Span,
- span_softmax: tracing::Span,
-}
-
-impl Attention {
- fn new(
- dim: usize,
- key_dim: usize,
- num_heads: usize,
- attn_ratio: usize,
- resolution: (usize, usize),
- vb: VarBuilder,
- ) -> Result<Self> {
- let d = attn_ratio * key_dim;
- let dh = d * num_heads;
- let nh_kd = key_dim * num_heads;
- let h = dh + nh_kd * 2;
- let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
- let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
- let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
-
- let points = (0..resolution.0)
- .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
- .collect::<Vec<_>>();
- let mut idxs = Vec::with_capacity(points.len() * points.len());
- let mut attention_offsets = std::collections::HashMap::new();
- for &(x1, y1) in points.iter() {
- for &(x2, y2) in points.iter() {
- let offset = ((x2 - x1).abs(), (y2 - y1).abs());
- let l = attention_offsets.len();
- let idx = attention_offsets.entry(offset).or_insert(l);
- idxs.push(*idx as u32)
- }
- }
- let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
- let idxs = Tensor::new(idxs, attention_biases.device())?;
- let ab =
- attention_biases
- .index_select(&idxs, 1)?
- .reshape(((), points.len(), points.len()))?;
- let span = tracing::span!(tracing::Level::TRACE, "attention");
- let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
- let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
- Ok(Self {
- norm,
- qkv,
- proj,
- ab,
- key_dim,
- num_heads,
- d,
- dh,
- scale: 1f64 / (key_dim as f64).sqrt(),
- span,
- span_matmul,
- span_softmax,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (b, n, _) = xs.dims3()?;
- let xs = xs.apply(&self.norm)?;
- let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
- let q = qkv
- .narrow(D::Minus1, 0, self.key_dim)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let k = qkv
- .narrow(D::Minus1, self.key_dim, self.key_dim)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let v = qkv
- .narrow(D::Minus1, 2 * self.key_dim, self.d)?
- .permute((0, 2, 1, 3))?
- .contiguous()?;
- let attn = {
- let _enter = self.span_matmul.enter();
- (q.matmul(&k.t()?)? * self.scale)?
- };
- let attn = attn.broadcast_add(&self.ab)?;
- let attn = {
- let _enter = self.span_softmax.enter();
- candle_nn::ops::softmax_last_dim(&attn)?
- };
- let attn = {
- let _enter = self.span_matmul.enter();
- attn.matmul(&v)?
- };
- attn.transpose(1, 2)?
- .reshape((b, n, self.dh))?
- .apply(&self.proj)
- }
-}
-
-#[derive(Debug)]
-struct TinyViTBlock {
- attn: Attention,
- local_conv: Conv2dBN,
- mlp: Mlp,
- window_size: usize,
- input_resolution: (usize, usize),
- span: tracing::Span,
-}
-
-impl TinyViTBlock {
- fn new(
- dim: usize,
- input_resolution: (usize, usize),
- num_heads: usize,
- window_size: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let head_dim = dim / num_heads;
- let attn = Attention::new(
- dim,
- head_dim,
- num_heads,
- 1,
- (window_size, window_size),
- vb.pp("attn"),
- )?;
- let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
- let cfg = candle_nn::Conv2dConfig {
- padding: LOCAL_CONV_SIZE / 2,
- groups: dim,
- ..Default::default()
- };
- let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
- let span = tracing::span!(tracing::Level::TRACE, "attention");
- Ok(Self {
- attn,
- local_conv,
- mlp,
- window_size,
- input_resolution,
- span,
- })
- }
-}
-
-impl Module for TinyViTBlock {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let (h, w) = self.input_resolution;
- let (b, l, c) = xs.dims3()?;
- let res_x = xs;
- let xs = if h == self.window_size && w == self.window_size {
- self.attn.forward(xs)?
- } else {
- let xs = xs.reshape((b, h, w, c))?;
- let pad_b = (self.window_size - h % self.window_size) % self.window_size;
- let pad_r = (self.window_size - w % self.window_size) % self.window_size;
-
- let xs = if pad_b > 0 {
- xs.pad_with_zeros(1, 0, pad_b)?
- } else {
- xs
- };
- let xs = if pad_r > 0 {
- xs.pad_with_zeros(2, 0, pad_r)?
- } else {
- xs
- };
- let (p_h, p_w) = (h + pad_b, w + pad_r);
- let n_h = p_h / self.window_size;
- let n_w = p_w / self.window_size;
- let xs = xs
- .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
- .transpose(2, 3)?
- .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
- let xs = self.attn.forward(&xs)?;
- let xs = xs
- .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
- .transpose(2, 3)?
- .reshape((b, p_h, p_w, c))?;
- let xs = if pad_r > 0 {
- xs.i((.., .., ..w))?.contiguous()?
- } else {
- xs
- };
- let xs = if pad_b > 0 {
- xs.i((.., ..h, ..))?.contiguous()?
- } else {
- xs
- };
- xs.reshape((b, l, c))?
- };
- let xs = (xs + res_x)?;
- let xs = xs
- .transpose(1, 2)?
- .reshape((b, c, h, w))?
- .apply(&self.local_conv)?
- .reshape((b, c, l))?
- .transpose(1, 2)?;
- &xs + self.mlp.forward(&xs)?
- }
-}
-
-#[derive(Debug)]
-struct BasicLayer {
- blocks: Vec<TinyViTBlock>,
- downsample: Option<PatchMerging>,
- span: tracing::Span,
-}
-
-impl BasicLayer {
- #[allow(clippy::too_many_arguments)]
- fn new(
- dim: usize,
- input_resolution: (usize, usize),
- depth: usize,
- num_heads: usize,
- window_size: usize,
- downsample: bool,
- out: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_b = vb.pp("blocks");
- let mut blocks = Vec::with_capacity(depth);
- for index in 0..depth {
- let block = TinyViTBlock::new(
- dim,
- input_resolution,
- num_heads,
- window_size,
- vb_b.pp(index),
- )?;
- blocks.push(block)
- }
- let downsample = if downsample {
- let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
- Some(downsample)
- } else {
- None
- };
- let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
- Ok(Self {
- blocks,
- downsample,
- span,
- })
- }
-}
-
-impl Module for BasicLayer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let mut xs = xs.clone();
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- match &self.downsample {
- None => Ok(xs),
- Some(downsample) => downsample.forward(&xs),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct TinyViT {
- patch_embed: PatchEmbed,
- layer0: ConvLayer,
- layers: Vec<BasicLayer>,
- // norm_head: candle_nn::LayerNorm,
- // head: candle_nn::Linear,
- neck_conv1: candle_nn::Conv2d,
- neck_ln1: crate::LayerNorm2d,
- neck_conv2: candle_nn::Conv2d,
- neck_ln2: crate::LayerNorm2d,
- span: tracing::Span,
- span_neck: tracing::Span,
-}
-
-impl TinyViT {
- pub fn new(
- embed_dims: &[usize],
- depths: &[usize],
- num_heads: &[usize],
- window_sizes: &[usize],
- _num_classes: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
- let patches_resolution = IMG_SIZE / 4;
-
- let vb_l = vb.pp("layers");
- let layer0 = ConvLayer::new(
- /* dim */ embed_dims[0],
- /* out */ embed_dims[1],
- /* input_resolution */ (patches_resolution, patches_resolution),
- /* depth */ depths[0],
- /* downsample */ true,
- /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
- vb_l.pp(0),
- )?;
-
- let num_layers = embed_dims.len();
- let mut layers = Vec::with_capacity(num_layers - 1);
- for i_layer in 1..num_layers {
- let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
- let layer = BasicLayer::new(
- /* dim */ embed_dims[i_layer],
- /* input_resolution */ (patches_resolution, patches_resolution),
- /* depth */ depths[i_layer],
- /* num_heads */ num_heads[i_layer],
- /* window_size */ window_sizes[i_layer],
- /* downsample */ i_layer < num_layers - 1,
- /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
- vb_l.pp(i_layer),
- )?;
- layers.push(layer)
- }
-
- let last_embed_dim = embed_dims[embed_dims.len() - 1];
- // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
- // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
- let neck_conv1 =
- candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
- let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
- let cfg = candle_nn::Conv2dConfig {
- padding: 1,
- ..Default::default()
- };
- let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
- let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
-
- let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
- let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
- Ok(Self {
- patch_embed,
- layer0,
- layers,
- neck_conv1,
- neck_ln1,
- neck_conv2,
- neck_ln2,
- span,
- span_neck,
- })
- }
-}
-
-impl Module for TinyViT {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- let xs = self.patch_embed.forward(xs)?;
- let mut xs = self.layer0.forward(&xs)?;
- for layer in self.layers.iter() {
- xs = layer.forward(&xs)?
- }
- let (b, _, c) = xs.dims3()?;
- let _enter = self.span_neck.enter();
- xs.reshape((b, 64, 64, c))?
- .permute((0, 3, 1, 2))?
- .apply(&self.neck_conv1)?
- .apply(&self.neck_ln1)?
- .apply(&self.neck_conv2)?
- .apply(&self.neck_ln2)
- }
-}
-
-pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
- TinyViT::new(
- /* embed_dims */ &[64, 128, 160, 320],
- /* depths */ &[2, 2, 6, 2],
- /* num_heads */ &[2, 4, 5, 10],
- /* window_sizes */ &[7, 7, 14, 7],
- /* num_classes */ 1000,
- vb,
- )
-}