Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/dinov2/main.rs                            | 283
-rw-r--r-- | candle-examples/examples/efficientnet/main.rs                      | 335
-rw-r--r-- | candle-examples/examples/quantized/main.rs                         |   2
-rw-r--r-- | candle-examples/examples/quantized/model.rs                        | 371
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs                  | 109
-rw-r--r-- | candle-examples/examples/segment-anything/model_image_encoder.rs   | 483
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs    | 239
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs  | 239
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs             | 411
-rw-r--r-- | candle-examples/examples/segment-anything/model_tiny_vit.rs        | 633
-rw-r--r-- | candle-examples/examples/segment-anything/model_transformer.rs     | 221
-rw-r--r-- | candle-examples/examples/yolo-v3/main.rs                           |   2
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs                           |   2
13 files changed, 16 insertions, 3314 deletions
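The hunks below delete the per-example model definitions and re-point each example at the shared implementations that now live in candle-transformers: candle_transformers::models::dinov2, ::efficientnet::{EfficientNet, MBConvConfig}, ::quantized_llama, and ::segment_anything::sam. A minimal sketch of the new usage pattern follows, based on the dinov2 hunk; the function signature and argument names are illustrative, while the loading and construction calls are the ones visible in the diff.

    // Sketch only: how an example builds a model after this change (assumes a
    // safetensors checkpoint path and an already-selected device, as in the
    // dinov2/main.rs hunk below).
    use candle::{DType, Device, D};
    use candle_nn::{Module, VarBuilder};
    use candle_transformers::models::dinov2;

    fn run(model_file: std::path::PathBuf, image: candle::Tensor, device: &Device) -> anyhow::Result<()> {
        let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
        let weights = weights.deserialize()?;
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        // Previously this constructor was defined locally in the example file.
        let model = dinov2::vit_small(vb)?;
        let logits = model.forward(&image.unsqueeze(0)?)?;
        let _prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
        Ok(())
    }
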
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index e80c81e2..d3adb37c 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -9,285 +9,10 @@ extern crate accelerate_src; use clap::Parser; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2; -const IMG_SIZE: usize = 518; -const PATCH_SIZE: usize = 14; -const NUM_CLASSES: usize = 1000; - -fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - if bias { - candle_nn::linear(in_dim, out_dim, vb) - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb) - } -} - -#[derive(Debug)] -struct Attention { - qkv: Linear, - proj: Linear, - num_heads: usize, - scale: f64, -} - -impl Attention { - fn new( - vb: VarBuilder, - dim: usize, - num_heads: usize, - qkv_bias: bool, - proj_bias: bool, - ) -> Result<Self> { - let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; - let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; - let scale = 1. / ((dim / num_heads) as f64).sqrt(); - Ok(Self { - qkv, - proj, - num_heads, - scale, - }) - } -} - -impl Module for Attention { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (b, n, c) = xs.dims3()?; - let qkv = self - .qkv - .forward(xs)? - .reshape((b, n, 3, self.num_heads, c / self.num_heads))? - .transpose(1, 2)? // 02134 - .transpose(0, 1)? // 20134 - .transpose(2, 3)?; // 20314 - let q = (qkv.i(0)? * self.scale)?; - let k = qkv.i(1)?; - let v = qkv.i(2)?; - let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; - let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; - self.proj.forward(&attn) - } -} - -#[derive(Debug)] -struct LayerScale { - gamma: Tensor, -} - -impl LayerScale { - fn new(vb: VarBuilder, dim: usize) -> Result<Self> { - let gamma = vb.get(dim, "gamma")?; - Ok(Self { gamma }) - } -} - -impl Module for LayerScale { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs.broadcast_mul(&self.gamma) - } -} - -#[derive(Debug)] -struct Mlp { - fc1: Linear, - fc2: Linear, -} - -impl Mlp { - fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { - let out_features = in_features; - let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; - let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.fc1.forward(xs)?.gelu()?; - self.fc2.forward(&xs) - } -} - -#[derive(Debug)] -struct Block { - norm1: LayerNorm, - attn: Attention, - ls1: LayerScale, - norm2: LayerNorm, - mlp: Mlp, - ls2: LayerScale, -} - -impl Block { - fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { - let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; - let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; - let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; - let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; - let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; - Ok(Self { - norm1, - attn, - ls1, - norm2, - mlp, - ls2, - }) - } -} - -impl Module for Block { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - let xs = self - .ls1 - .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; - 
let xs = (xs + residual)?; - let residual = &xs; - let xs = self - .ls2 - .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; - xs + residual - } -} - -#[derive(Debug)] -struct PatchEmbed { - proj: candle_nn::Conv2d, - patch_size: (usize, usize), - num_patches: usize, -} - -impl PatchEmbed { - fn new( - vb: VarBuilder, - img_size: usize, - patch_size: usize, - in_chans: usize, - embed_dim: usize, - ) -> Result<Self> { - let config = candle_nn::Conv2dConfig { - stride: patch_size, - ..Default::default() - }; - let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; - let num_patches = (img_size / patch_size) * (img_size / patch_size); - Ok(Self { - proj, - patch_size: (patch_size, patch_size), - num_patches, - }) - } -} - -impl Module for PatchEmbed { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _c, h, w) = xs.dims4()?; - let (patch_h, patch_w) = self.patch_size; - if (h % patch_h) != 0 { - candle::bail!("image height {h} is not a multiple of patch height {patch_h}") - } - if (w % patch_w) != 0 { - candle::bail!("image width {w} is not a multiple of patch width {patch_w}") - } - let xs = self.proj.forward(xs)?; - let (b, c, h, w) = xs.dims4()?; - // flatten embeddings. - xs.reshape((b, c, h * w))?.transpose(1, 2) - } -} - -#[derive(Debug)] -pub struct DinoVisionTransformer { - patch_embed: PatchEmbed, - cls_token: Tensor, - pos_embed: Tensor, - blocks: Vec<Block>, - norm: LayerNorm, - head: Linear, -} - -impl DinoVisionTransformer { - pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { - let patch_embed = - PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; - let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; - let num_tokens = 1; - let pos_embed = vb.get( - (1, patch_embed.num_patches + num_tokens, embed_dim), - "pos_embed", - )?; - let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; - let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; - let vb_b = vb.pp("blocks"); - let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) - .collect::<Result<Vec<_>>>()?; - Ok(Self { - patch_embed, - cls_token, - pos_embed, - blocks, - norm, - head, - }) - } - - fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { - let npatch = xs.dim(1)? - 1; - let n = self.pos_embed.dim(1)? - 1; - let sqrt_n = (n as f64).sqrt(); - if npatch == n && w == h { - return Ok(xs.clone()); - } - let class_pos_embed = self.pos_embed.i((.., ..1))?; - let patch_pos_embed = self.pos_embed.i((.., 1..))?; - let dim = xs.dim(D::Minus1)?; - let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); - let patch_pos_embed = patch_pos_embed - .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? - .transpose(2, 3)? - .transpose(1, 2)?; - // This uses bicubic interpolation in the original implementation. - let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; - let el_count = patch_pos_embed.shape().elem_count(); - let patch_pos_embed = - patch_pos_embed - .transpose(1, 2)? - .transpose(2, 3)? - .reshape((1, el_count / dim, dim))?; - Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) - } - - fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _nc, w, h) = xs.dims4()?; - let xs = self.patch_embed.forward(xs)?; - let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; - &xs + &self.interpolate_pos_encoding(&xs, w, h)? 
- } -} - -impl Module for DinoVisionTransformer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.prepare_tokens_with_mask(xs)?; - for blk in self.blocks.iter() { - xs = blk.forward(&xs)? - } - let xs = self.norm.forward(&xs)?; - let xs_norm_clstoken = xs.i((.., 0))?; - let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; - let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; - self.head.forward(&xs) - } -} - -pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { - DinoVisionTransformer::new(vb, 12, 384, 6) -} #[derive(Parser)] struct Args { #[arg(long)] @@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> { let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let model = vit_small(vb)?; + let model = dinov2::vit_small(vb)?; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; let prs = candle_nn::ops::softmax(&logits, D::Minus1)? diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs index cbe2c90a..1e45e301 100644 --- a/candle-examples/examples/efficientnet/main.rs +++ b/candle-examples/examples/efficientnet/main.rs @@ -8,340 +8,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig}; use clap::{Parser, ValueEnum}; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn as nn; -use nn::{Module, VarBuilder}; - -// Based on the Python version from torchvision. -// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 -#[derive(Debug, Clone, Copy)] -pub struct MBConvConfig { - expand_ratio: f64, - kernel: usize, - stride: usize, - input_channels: usize, - out_channels: usize, - num_layers: usize, -} - -fn make_divisible(v: f64, divisor: usize) -> usize { - let min_value = divisor; - let new_v = usize::max( - min_value, - (v + divisor as f64 * 0.5) as usize / divisor * divisor, - ); - if (new_v as f64) < 0.9 * v { - new_v + divisor - } else { - new_v - } -} - -fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { - let bneck_conf = |e, k, s, i, o, n| { - let input_channels = make_divisible(i as f64 * width_mult, 8); - let out_channels = make_divisible(o as f64 * width_mult, 8); - let num_layers = (n as f64 * depth_mult).ceil() as usize; - MBConvConfig { - expand_ratio: e, - kernel: k, - stride: s, - input_channels, - out_channels, - num_layers, - } - }; - vec![ - bneck_conf(1., 3, 1, 32, 16, 1), - bneck_conf(6., 3, 2, 16, 24, 2), - bneck_conf(6., 5, 2, 24, 40, 2), - bneck_conf(6., 3, 2, 40, 80, 3), - bneck_conf(6., 5, 1, 80, 112, 3), - bneck_conf(6., 5, 2, 112, 192, 4), - bneck_conf(6., 3, 1, 192, 320, 1), - ] -} - -impl MBConvConfig { - fn b0() -> Vec<Self> { - bneck_confs(1.0, 1.0) - } - fn b1() -> Vec<Self> { - bneck_confs(1.0, 1.1) - } - fn b2() -> Vec<Self> { - bneck_confs(1.1, 1.2) - } - fn b3() -> Vec<Self> { - bneck_confs(1.2, 1.4) - } - fn b4() -> Vec<Self> { - bneck_confs(1.4, 1.8) - } - fn b5() -> Vec<Self> { - bneck_confs(1.6, 2.2) - } - fn b6() -> Vec<Self> { - bneck_confs(1.8, 2.6) - } - fn b7() -> Vec<Self> { - bneck_confs(2.0, 3.1) - } -} - -/// Conv2D with same padding. 
-#[derive(Debug)] -struct Conv2DSame { - conv2d: nn::Conv2d, - s: usize, - k: usize, -} - -impl Conv2DSame { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - bias: bool, - ) -> Result<Self> { - let conv_config = nn::Conv2dConfig { - stride, - groups, - ..Default::default() - }; - let conv2d = if bias { - nn::conv2d(i, o, k, conv_config, vb)? - } else { - nn::conv2d_no_bias(i, o, k, conv_config, vb)? - }; - Ok(Self { - conv2d, - s: stride, - k, - }) - } -} - -impl Module for Conv2DSame { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let s = self.s; - let k = self.k; - let (_, _, ih, iw) = xs.dims4()?; - let oh = (ih + s - 1) / s; - let ow = (iw + s - 1) / s; - let pad_h = usize::max((oh - 1) * s + k - ih, 0); - let pad_w = usize::max((ow - 1) * s + k - iw, 0); - if pad_h > 0 || pad_w > 0 { - let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; - let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; - self.conv2d.forward(&xs) - } else { - self.conv2d.forward(xs) - } - } -} - -#[derive(Debug)] -struct ConvNormActivation { - conv2d: Conv2DSame, - bn2d: nn::BatchNorm, - activation: bool, -} - -impl ConvNormActivation { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - ) -> Result<Self> { - let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; - let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; - Ok(Self { - conv2d, - bn2d, - activation: true, - }) - } - - fn no_activation(self) -> Self { - Self { - activation: false, - ..self - } - } -} - -impl Module for ConvNormActivation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.conv2d.forward(xs)?; - let xs = self.bn2d.forward(&xs)?; - if self.activation { - swish(&xs) - } else { - Ok(xs) - } - } -} - -#[derive(Debug)] -struct SqueezeExcitation { - fc1: Conv2DSame, - fc2: Conv2DSame, -} - -impl SqueezeExcitation { - fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { - let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; - let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for SqueezeExcitation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - // equivalent to adaptive_avg_pool2d([1, 1]) - let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; - let xs = self.fc1.forward(&xs)?; - let xs = swish(&xs)?; - let xs = self.fc2.forward(&xs)?; - let xs = nn::ops::sigmoid(&xs)?; - residual.broadcast_mul(&xs) - } -} - -#[derive(Debug)] -struct MBConv { - expand_cna: Option<ConvNormActivation>, - depthwise_cna: ConvNormActivation, - squeeze_excitation: SqueezeExcitation, - project_cna: ConvNormActivation, - config: MBConvConfig, -} - -impl MBConv { - fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { - let vb = vb.pp("block"); - let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); - let expand_cna = if exp != c.input_channels { - Some(ConvNormActivation::new( - vb.pp("0"), - c.input_channels, - exp, - 1, - 1, - 1, - )?) 
- } else { - None - }; - let start_index = if expand_cna.is_some() { 1 } else { 0 }; - let depthwise_cna = - ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; - let squeeze_channels = usize::max(1, c.input_channels / 4); - let squeeze_excitation = - SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; - let project_cna = - ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? - .no_activation(); - Ok(Self { - expand_cna, - depthwise_cna, - squeeze_excitation, - project_cna, - config: c, - }) - } -} - -impl Module for MBConv { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let use_res_connect = - self.config.stride == 1 && self.config.input_channels == self.config.out_channels; - let ys = match &self.expand_cna { - Some(expand_cna) => expand_cna.forward(xs)?, - None => xs.clone(), - }; - let ys = self.depthwise_cna.forward(&ys)?; - let ys = self.squeeze_excitation.forward(&ys)?; - let ys = self.project_cna.forward(&ys)?; - if use_res_connect { - ys + xs - } else { - Ok(ys) - } - } -} - -fn swish(s: &Tensor) -> Result<Tensor> { - s * nn::ops::sigmoid(s)? -} - -#[derive(Debug)] -struct EfficientNet { - init_cna: ConvNormActivation, - blocks: Vec<MBConv>, - final_cna: ConvNormActivation, - classifier: nn::Linear, -} - -impl EfficientNet { - fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { - let f_p = p.pp("features"); - let first_in_c = configs[0].input_channels; - let last_out_c = configs.last().unwrap().out_channels; - let final_out_c = 4 * last_out_c; - let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; - let nconfigs = configs.len(); - let mut blocks = vec![]; - for (index, cnf) in configs.into_iter().enumerate() { - let f_p = f_p.pp(index + 1); - for r_index in 0..cnf.num_layers { - let cnf = if r_index == 0 { - cnf - } else { - MBConvConfig { - input_channels: cnf.out_channels, - stride: 1, - ..cnf - } - }; - blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) - } - } - let final_cna = - ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; - let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; - Ok(Self { - init_cna, - blocks, - final_cna, - classifier, - }) - } -} - -impl Module for EfficientNet { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.init_cna.forward(xs)?; - for block in self.blocks.iter() { - xs = block.forward(&xs)? 
- } - let xs = self.final_cna.forward(&xs)?; - // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) - let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; - self.classifier.forward(&xs) - } -} - #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { B0, diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index a3f98d8e..c8179d33 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file}; use candle::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; -mod model; +use candle_transformers::models::quantized_llama as model; use model::ModelWeights; const DEFAULT_PROMPT: &str = "My favorite theorem is "; diff --git a/candle-examples/examples/quantized/model.rs b/candle-examples/examples/quantized/model.rs deleted file mode 100644 index da0bd0b0..00000000 --- a/candle-examples/examples/quantized/model.rs +++ /dev/null @@ -1,371 +0,0 @@ -use std::collections::HashMap; - -use candle::quantized::QTensor; -use candle::quantized::{ggml_file, gguf_file}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module}; - -pub const MAX_SEQ_LEN: usize = 4096; - -struct RmsNorm { - inner: candle_nn::LayerNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(scale: QTensor, eps: f32) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&Device::Cpu)?; - let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -// QMatMul wrapper adding some tracing. -struct QMatMul { - inner: candle::quantized::QMatMul, - span: tracing::Span, -} - -impl QMatMul { - fn from_qtensor(qtensor: QTensor) -> Self { - let inner = candle::quantized::QMatMul::from_qtensor(qtensor); - let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - Self { inner, span } - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -struct LayerWeights { - attention_wq: QMatMul, - attention_wk: QMatMul, - attention_wv: QMatMul, - attention_wo: QMatMul, - attention_norm: RmsNorm, - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, - ffn_norm: RmsNorm, - n_head: usize, - n_kv_head: usize, - head_dim: usize, - cos: Tensor, - sin: Tensor, - kv_cache: Option<(Tensor, Tensor)>, - span_attn: tracing::Span, - span_rot: tracing::Span, - span_mlp: tracing::Span, -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -impl LayerWeights { - fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_rot.enter(); - let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; - let cos = self - .cos - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let sin = self - .sin - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - // This mimics the llama.cpp behavior. 
- // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 - // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. - // The resulting y0 and y1 are also interleaved with: - // y0 = x0*cos - x1*sin - // y1 = x0*sin + x1*cos - let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; - let x0 = x.narrow(D::Minus1, 0, 1)?; - let x1 = x.narrow(D::Minus1, 1, 1)?; - let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; - let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; - let rope = Tensor::cat(&[y0, y1], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; - Ok(rope) - } - - fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_attn.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; - let q = self.attention_wq.forward(x)?; - let k = self.attention_wk.forward(x)?; - let v = self.attention_wv.forward(x)?; - - let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?; - let k = self.apply_rotary_emb(&k, index_pos)?; - - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((k_cache, v_cache)) => { - if index_pos == 0 { - (k, v) - } else { - let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; - let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; - (k, v) - } - } - }; - self.kv_cache = Some((k.clone(), v.clone())); - - // Support for MQA, useful for 70B models. - let k = self.repeat_kv(k)?; - let v = self.repeat_kv(v)?; - - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = mask.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attention_wo.forward(&y)?; - Ok(y) - } - - fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { - let n_rep = self.n_head / self.n_kv_head; - if n_rep == 1 { - Ok(x) - } else { - let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? - .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; - Ok(x) - } - } -} - -pub struct ModelWeights { - tok_embeddings: Embedding, - layers: Vec<LayerWeights>, - norm: RmsNorm, - output: QMatMul, - masks: HashMap<usize, Tensor>, - span: tracing::Span, - span_output: tracing::Span, -} - -fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? 
- .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - Ok((cos, sin)) -} - -impl ModelWeights { - pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { - let cpu = &Device::Cpu; - let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; - let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; - let tok_embeddings = ct.remove("tok_embeddings.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; - let output = ct.remove("output.weight")?; - let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); - for layer_idx in 0..ct.hparams.n_layer { - let prefix = format!("layers.{layer_idx}"); - let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; - let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; - let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; - let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; - let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; - let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; - let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, - n_head: ct.hparams.n_head as usize, - n_kv_head: ct.hparams.n_head as usize / gqa, - head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - pub fn from_gguf<R: std::io::Seek + std::io::Read>( - ct: gguf_file::Content, - reader: &mut R, - ) -> Result<Self> { - let cpu = &Device::Cpu; - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - - // Parameter extraction from metadata. - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; - // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
- let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; - - let rope_freq_base = md_get("llama.rope.freq_base") - .and_then(|m| m.to_f32()) - .unwrap_or(10000f32); - let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; - let mut layers = Vec::with_capacity(block_count); - for layer_idx in 0..block_count { - let prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, - n_head: head_count, - n_kv_head: head_count_kv, - head_dim: embedding_length / head_count, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, embedding_length), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - fn mask(&mut self, t: usize) -> Result<Tensor> { - if let Some(mask) = self.masks.get(&t) { - Ok(mask.clone()) - } else { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; - self.masks.insert(t, mask.clone()); - Ok(mask) - } - } - - pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let (_b_sz, seq_len) = x.dims2()?; - let mask = self.mask(seq_len)?; - let _enter = self.span.enter(); - let mut layer_in = self.tok_embeddings.forward(x)?; - for layer in self.layers.iter_mut() { - let x = layer_in; - let residual = &x; - let x = layer.attention_norm.forward(&x)?; - let attn = layer.forward_attn(&x, &mask, index_pos)?; - let x = (attn + residual)?; - - // MLP - let _enter = layer.span_mlp.enter(); - let residual = &x; - let x = layer.ffn_norm.forward(&x)?; - let w1 = 
layer.feed_forward_w1.forward(&x)?; - let w3 = layer.feed_forward_w3.forward(&x)?; - let mlp = layer - .feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; - layer_in = (mlp + residual)?; - } - let x = self.norm.forward(&layer_in)?; - let x = x.i((.., seq_len - 1, ..))?; - let _enter = self.span_output.enter(); - self.output.forward(&x) - } -} diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 9ce2f158..21ba0415 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -7,108 +7,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -pub mod model_image_encoder; -pub mod model_mask_decoder; -pub mod model_prompt_encoder; -pub mod model_sam; -pub mod model_tiny_vit; -pub mod model_transformer; - -use candle::{DType, Result, Tensor}; -use candle_nn::{Module, VarBuilder}; +use candle::DType; +use candle_nn::VarBuilder; +use candle_transformers::models::segment_anything::sam; use clap::Parser; -pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - let inner = if bias { - candle_nn::linear(in_dim, out_dim, vb)? - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb)? - }; - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Ok(Linear { inner, span }) -} - -#[derive(Debug)] -pub struct LayerNorm2d { - weight: Tensor, - bias: Tensor, - num_channels: usize, - eps: f64, -} - -impl LayerNorm2d { - pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let weight = vb.get(num_channels, "weight")?; - let bias = vb.get(num_channels, "bias")?; - Ok(Self { - weight, - bias, - num_channels, - eps, - }) - } -} - -impl Module for LayerNorm2d { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let u = xs.mean_keepdim(1)?; - let xs = xs.broadcast_sub(&u)?; - let s = xs.sqr()?.mean_keepdim(1)?; - let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; - xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? - .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) - } -} - -#[derive(Debug)] -pub struct MlpBlock { - lin1: Linear, - lin2: Linear, - activation: candle_nn::Activation, - span: tracing::Span, -} - -impl MlpBlock { - pub fn new( - embedding_dim: usize, - mlp_dim: usize, - activation: candle_nn::Activation, - vb: VarBuilder, - ) -> Result<Self> { - let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; - let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; - let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); - Ok(Self { - lin1, - lin2, - activation, - span, - }) - } -} - -impl Module for MlpBlock { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.lin1)? - .apply(&self.activation)? 
- .apply(&self.lin2) - } -} - -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Parser)] struct Args { #[arg(long)] @@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> { let (_c, h, w) = image.dims3()?; (image, h, w) } else { - let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?; + let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; (image.to_device(&device)?, h, w) }; println!("loaded image {image:?}"); @@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> { let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let sam = if args.use_tiny { - model_sam::Sam::new_tiny(vb)? // tiny vit_t + sam::Sam::new_tiny(vb)? // tiny vit_t } else { - model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b }; if args.generate_masks { diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs deleted file mode 100644 index 76cd15d0..00000000 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ /dev/null @@ -1,483 +0,0 @@ -use candle::{DType, IndexOp, Result, Tensor}; -use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; - -#[derive(Debug)] -struct PatchEmbed { - proj: candle_nn::Conv2d, - span: tracing::Span, -} - -impl PatchEmbed { - fn new( - in_chans: usize, - embed_dim: usize, - k_size: usize, - stride: usize, - padding: usize, - vb: VarBuilder, - ) -> Result<Self> { - let cfg = candle_nn::Conv2dConfig { - stride, - padding, - ..Default::default() - }; - let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?; - let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); - Ok(Self { proj, span }) - } -} - -impl Module for PatchEmbed { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.proj)?.permute((0, 2, 3, 1)) - } -} - -// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final -// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096 -// (attn.reshape((b, q_h, q_w, k_h, k_w))? -// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? -// .reshape((b, q_h * q_w, k_h * k_w)) -// Ideally we would perform this operation in place but this is not supported in candle at the -// moment. We should also investigate using f16 rather than f32. 
-struct Add3(usize, usize, usize, usize, usize); -impl candle::CustomOp3 for Add3 { - fn name(&self) -> &'static str { - "add3" - } - - fn cpu_fwd( - &self, - s1: &candle::CpuStorage, - l1: &candle::Layout, - s2: &candle::CpuStorage, - l2: &candle::Layout, - s3: &candle::CpuStorage, - l3: &candle::Layout, - ) -> Result<(candle::CpuStorage, candle::Shape)> { - use rayon::prelude::*; - - let Add3(b, q_h, q_w, k_h, k_w) = *self; - let s1 = s1.as_slice::<f32>()?; - let s1 = match l1.contiguous_offsets() { - None => candle::bail!("input1 has to be contiguous"), - Some((o1, o2)) => &s1[o1..o2], - }; - let s2 = s2.as_slice::<f32>()?; - let s2 = match l2.contiguous_offsets() { - None => candle::bail!("input2 has to be contiguous"), - Some((o1, o2)) => &s2[o1..o2], - }; - let s3 = s3.as_slice::<f32>()?; - let s3 = match l3.contiguous_offsets() { - None => candle::bail!("input3 has to be contiguous"), - Some((o1, o2)) => &s3[o1..o2], - }; - let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w]; - dst.par_chunks_exact_mut(k_h * k_w) - .enumerate() - .for_each(|(b_idx, dst)| { - let s1_idx = b_idx * k_h * k_w; - let s2_idx = b_idx * k_h; - let s3_idx = b_idx * k_w; - for h_idx in 0..k_h { - let s1_idx = s1_idx + h_idx * k_w; - let s2_idx = s2_idx + h_idx; - let dst_idx = h_idx * k_w; - for w_idx in 0..k_w { - let s1_idx = s1_idx + w_idx; - let s3_idx = s3_idx + w_idx; - let dst_idx = dst_idx + w_idx; - dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx] - } - } - }); - let dst = candle::WithDType::to_cpu_storage_owned(dst); - Ok((dst, (b, q_h * q_w, k_h * k_w).into())) - } -} - -#[derive(Debug)] -struct Attention { - qkv: crate::Linear, - proj: crate::Linear, - num_heads: usize, - scale: f64, - rel_pos_hw: Option<(Tensor, Tensor)>, - span: tracing::Span, - span_matmul: tracing::Span, - span_rel_pos: tracing::Span, - span_softmax: tracing::Span, -} - -impl Attention { - fn new( - dim: usize, - num_heads: usize, - qkv_bias: bool, - use_rel_pos: bool, - input_size: (usize, usize), - vb: VarBuilder, - ) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "attention"); - let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); - let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); - let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); - let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; - let proj = crate::linear(vb.pp("proj"), dim, dim, true)?; - let head_dim = dim / num_heads; - let scale = 1. / (head_dim as f64).sqrt(); - let rel_pos_hw = if use_rel_pos { - let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?; - let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?; - Some((h, w)) - } else { - None - }; - Ok(Self { - qkv, - proj, - num_heads, - scale, - rel_pos_hw, - span, - span_matmul, - span_rel_pos, - span_softmax, - }) - } - - fn add_decomposed_rel_pos( - &self, - attn: Tensor, - q: &Tensor, - (q_h, q_w): (usize, usize), - (k_h, k_w): (usize, usize), - ) -> Result<Tensor> { - match &self.rel_pos_hw { - Some((rel_pos_h, rel_pos_w)) => { - let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; - let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; - let (b, _, dim) = q.dims3()?; - let r_q = q.reshape((b, q_h, q_w, dim))?; - // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) - let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; - // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) - let rel_w = r_q - .transpose(1, 2)? // -> bwhc - .contiguous()? - .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? 
// bwhc,bwck -> bwhk - .transpose(1, 2)? - .contiguous()?; - if attn.device().is_cpu() { - let op = Add3(b, q_h, q_w, k_h, k_w); - attn.apply_op3_no_bwd(&rel_h, &rel_w, &op) - } else { - (attn.reshape((b, q_h, q_w, k_h, k_w))? - + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? - .reshape((b, q_h * q_w, k_h * k_w)) - } - } - None => Ok(attn), - } - } -} - -fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> { - let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; - let dev = rel_pos.device(); - let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { - todo!("interpolation") - } else { - rel_pos - }; - let q_coords = Tensor::arange(0u32, q_size as u32, dev)? - .reshape((q_size, 1))? - .to_dtype(DType::F32)?; - let k_coords = Tensor::arange(0u32, k_size as u32, dev)? - .reshape((1, k_size))? - .to_dtype(DType::F32)?; - let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; - let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; - let relative_coords = (q_coords.broadcast_sub(&k_coords)? - + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; - let (d1, d2) = relative_coords.dims2()?; - let relative_coords = relative_coords.to_dtype(DType::U32)?; - rel_pos_resized - .index_select(&relative_coords.reshape(d1 * d2)?, 0)? - .reshape((d1, d2, ())) -} - -impl Module for Attention { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let (b, h, w, c) = xs.dims4()?; - let qkv = self - .qkv - .forward(&xs.flatten_to(1)?)? - .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? - .permute((2, 0, 3, 1, 4))? - .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; - let q = qkv.i(0)?; - let k = qkv.i(1)?; - let v = qkv.i(2)?; - let attn = { - let _enter = self.span_matmul.enter(); - (&q * self.scale)?.matmul(&k.t()?)? - }; - let attn = { - let _enter = self.span_rel_pos.enter(); - self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))? - }; - let attn = { - let _enter = self.span_softmax.enter(); - candle_nn::ops::softmax_last_dim(&attn)? - }; - let attn = { - let _enter = self.span_matmul.enter(); - attn.matmul(&v)? - }; - let attn = attn - .reshape((b, self.num_heads, h, w, c / self.num_heads))? - .permute((0, 2, 3, 1, 4))? 
- .reshape((b, h * w, c))?; - self.proj.forward(&attn)?.reshape((b, h, w, c)) - } -} - -#[derive(Debug)] -struct Block { - norm1: LayerNorm, - attn: Attention, - norm2: LayerNorm, - mlp: crate::MlpBlock, - window_size: usize, - span: tracing::Span, -} - -impl Block { - fn new( - dim: usize, - num_heads: usize, - qkv_bias: bool, - use_rel_pos: bool, - window_size: usize, - input_size: (usize, usize), - vb: VarBuilder, - ) -> Result<Self> { - let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; - let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; - let input_size_attn = if window_size == 0 { - input_size - } else { - (window_size, window_size) - }; - let attn = Attention::new( - dim, - num_heads, - qkv_bias, - use_rel_pos, - input_size_attn, - vb.pp("attn"), - )?; - let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; - let span = tracing::span!(tracing::Level::TRACE, "ie-block"); - Ok(Self { - norm1, - attn, - norm2, - mlp, - window_size, - span, - }) - } -} - -fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { - let (b, h, w, c) = xs.dims4()?; - let pad_h = (window_size - h % window_size) % window_size; - let pad_w = (window_size - w % window_size) % window_size; - let xs = if pad_h > 0 { - xs.pad_with_zeros(1, 0, pad_h)? - } else { - xs - }; - let xs = if pad_w > 0 { - xs.pad_with_zeros(2, 0, pad_w)? - } else { - xs - }; - let (h_p, w_p) = (h + pad_h, w + pad_w); - let windows = xs - .reshape(( - b, - h_p / window_size, - window_size, - w_p / window_size, - window_size, - c, - ))? - .transpose(2, 3)? - .contiguous()? - .flatten_to(2)?; - Ok((windows, (h_p, w_p))) -} - -fn window_unpartition( - windows: Tensor, - window_size: usize, - (h_p, w_p): (usize, usize), - (h, w): (usize, usize), -) -> Result<Tensor> { - let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); - let xs = windows - .reshape(( - b, - h_p / window_size, - w_p / window_size, - window_size, - window_size, - windows.elem_count() / b / h_p / w_p, - ))? - .transpose(2, 3)? - .contiguous()? - .reshape((b, h_p, w_p, ()))?; - let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; - let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; - Ok(xs) -} - -impl Module for Block { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let shortcut = xs; - let xs = self.norm1.forward(xs)?; - let hw = (xs.dim(1)?, xs.dim(2)?); - let (xs, pad_hw) = if self.window_size > 0 { - window_partition(xs, self.window_size)? - } else { - (xs, (0, 0)) - }; - let xs = self.attn.forward(&xs)?; - let xs = if self.window_size > 0 { - window_unpartition(xs, self.window_size, pad_hw, hw)? - } else { - xs - }; - let xs = (xs + shortcut)?; - &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? 
- } -} - -#[derive(Debug)] -pub struct ImageEncoderViT { - patch_embed: PatchEmbed, - blocks: Vec<Block>, - neck_conv1: candle_nn::Conv2d, - neck_ln1: crate::LayerNorm2d, - neck_conv2: candle_nn::Conv2d, - neck_ln2: crate::LayerNorm2d, - pos_embed: Option<Tensor>, - span: tracing::Span, -} - -impl ImageEncoderViT { - #[allow(clippy::too_many_arguments)] - pub fn new( - img_size: usize, - patch_size: usize, - in_chans: usize, - embed_dim: usize, - depth: usize, - num_heads: usize, - out_chans: usize, - qkv_bias: bool, - use_rel_pos: bool, - use_abs_pos: bool, - window_size: usize, - global_attn_indexes: &[usize], - vb: VarBuilder, - ) -> Result<Self> { - let patch_embed = PatchEmbed::new( - in_chans, - embed_dim, - patch_size, - patch_size, - 0, - vb.pp("patch_embed"), - )?; - let mut blocks = Vec::with_capacity(depth); - let vb_b = vb.pp("blocks"); - for i in 0..depth { - let window_size = if global_attn_indexes.contains(&i) { - 0 - } else { - window_size - }; - let block = Block::new( - embed_dim, - num_heads, - qkv_bias, - use_rel_pos, - window_size, - (img_size / patch_size, img_size / patch_size), - vb_b.pp(i), - )?; - blocks.push(block) - } - let neck_conv1 = candle_nn::conv2d_no_bias( - embed_dim, - out_chans, - 1, - Default::default(), - vb.pp("neck.0"), - )?; - let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; - let cfg = candle_nn::Conv2dConfig { - padding: 1, - ..Default::default() - }; - let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; - let pos_embed = if use_abs_pos { - let p = vb.get( - (1, img_size / patch_size, img_size / patch_size, embed_dim), - "pos_embed", - )?; - Some(p) - } else { - None - }; - let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit"); - Ok(Self { - patch_embed, - blocks, - neck_conv1, - neck_ln1, - neck_conv2, - neck_ln2, - pos_embed, - span, - }) - } -} - -impl Module for ImageEncoderViT { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let xs = self.patch_embed.forward(xs)?; - let mut xs = match &self.pos_embed { - Some(pos_embed) => (xs + pos_embed)?, - None => xs, - }; - for block in self.blocks.iter() { - xs = block.forward(&xs)? - } - xs.permute((0, 3, 1, 2))? - .apply(&self.neck_conv1)? - .apply(&self.neck_ln1)? - .apply(&self.neck_conv2)? 
- .apply(&self.neck_ln2) - } -} diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs deleted file mode 100644 index c02b44a7..00000000 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ /dev/null @@ -1,239 +0,0 @@ -use candle::{IndexOp, Result, Tensor}; -use candle_nn::{Module, VarBuilder}; - -use crate::model_transformer::TwoWayTransformer; - -#[derive(Debug)] -struct MlpMaskDecoder { - layers: Vec<crate::Linear>, - sigmoid_output: bool, - span: tracing::Span, -} - -impl MlpMaskDecoder { - fn new( - input_dim: usize, - hidden_dim: usize, - output_dim: usize, - num_layers: usize, - sigmoid_output: bool, - vb: VarBuilder, - ) -> Result<Self> { - let mut layers = Vec::with_capacity(num_layers); - let vb = vb.pp("layers"); - for i in 0..num_layers { - let in_dim = if i == 0 { input_dim } else { hidden_dim }; - let out_dim = if i + 1 == num_layers { - output_dim - } else { - hidden_dim - }; - let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?; - layers.push(layer) - } - let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); - Ok(Self { - layers, - sigmoid_output, - span, - }) - } -} - -impl Module for MlpMaskDecoder { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let mut xs = xs.clone(); - for (i, layer) in self.layers.iter().enumerate() { - xs = layer.forward(&xs)?; - if i + 1 < self.layers.len() { - xs = xs.relu()? - } - } - if self.sigmoid_output { - candle_nn::ops::sigmoid(&xs) - } else { - Ok(xs) - } - } -} - -#[derive(Debug)] -pub struct MaskDecoder { - iou_token: candle_nn::Embedding, - mask_tokens: candle_nn::Embedding, - iou_prediction_head: MlpMaskDecoder, - output_upscaling_conv1: candle_nn::ConvTranspose2d, - output_upscaling_ln: crate::LayerNorm2d, - output_upscaling_conv2: candle_nn::ConvTranspose2d, - num_mask_tokens: usize, - output_hypernetworks_mlps: Vec<MlpMaskDecoder>, - transformer: TwoWayTransformer, - span: tracing::Span, -} - -impl MaskDecoder { - pub fn new( - transformer_dim: usize, - num_multimask_outputs: usize, - iou_head_depth: usize, - iou_head_hidden_dim: usize, - vb: VarBuilder, - ) -> Result<Self> { - let num_mask_tokens = num_multimask_outputs + 1; - let iou_prediction_head = MlpMaskDecoder::new( - transformer_dim, - iou_head_hidden_dim, - num_mask_tokens, - iou_head_depth, - false, - vb.pp("iou_prediction_head"), - )?; - let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?; - let mask_tokens = - candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?; - let cfg = candle_nn::ConvTranspose2dConfig { - stride: 2, - ..Default::default() - }; - let output_upscaling_conv1 = candle_nn::conv_transpose2d( - transformer_dim, - transformer_dim / 4, - 2, - cfg, - vb.pp("output_upscaling.0"), - )?; - let output_upscaling_ln = - crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; - let output_upscaling_conv2 = candle_nn::conv_transpose2d( - transformer_dim / 4, - transformer_dim / 8, - 2, - cfg, - vb.pp("output_upscaling.3"), - )?; - let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens); - let vb_o = vb.pp("output_hypernetworks_mlps"); - for i in 0..num_mask_tokens { - let mlp = MlpMaskDecoder::new( - transformer_dim, - transformer_dim, - transformer_dim / 8, - 3, - false, - vb_o.pp(i), - )?; - output_hypernetworks_mlps.push(mlp) - } - let transformer = TwoWayTransformer::new( - /* depth */ 2, - /* 
embedding_dim */ transformer_dim, - /* num_heads */ 8, - /* mlp_dim */ 2048, - vb.pp("transformer"), - )?; - let span = tracing::span!(tracing::Level::TRACE, "mask-decoder"); - Ok(Self { - iou_token, - mask_tokens, - iou_prediction_head, - output_upscaling_conv1, - output_upscaling_ln, - output_upscaling_conv2, - num_mask_tokens, - output_hypernetworks_mlps, - transformer, - span, - }) - } - - pub fn forward( - &self, - image_embeddings: &Tensor, - image_pe: &Tensor, - sparse_prompt_embeddings: &Tensor, - dense_prompt_embeddings: &Tensor, - multimask_output: bool, - ) -> Result<(Tensor, Tensor)> { - let _enter = self.span.enter(); - let (masks, iou_pred) = self.predict_masks( - image_embeddings, - image_pe, - sparse_prompt_embeddings, - dense_prompt_embeddings, - )?; - let masks = if multimask_output { - masks.i((.., 1..))? - } else { - masks.i((.., 0..1))? - }; - let iou_pred = if multimask_output { - iou_pred.i((.., 1..))? - } else { - iou_pred.i((.., 0..1))? - }; - Ok((masks, iou_pred)) - } - - fn predict_masks( - &self, - image_embeddings: &Tensor, - image_pe: &Tensor, - sparse_prompt_embeddings: &Tensor, - dense_prompt_embeddings: &Tensor, - ) -> Result<(Tensor, Tensor)> { - // Concatenate ouput tokens. - let output_tokens = Tensor::cat( - &[self.iou_token.embeddings(), self.mask_tokens.embeddings()], - 0, - )?; - let (d1, d2) = output_tokens.dims2()?; - let output_tokens = - output_tokens - .unsqueeze(0)? - .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?; - let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?; - - // Expand per-image data in batch direction to be per mask - let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?; - let src = src.broadcast_add(dense_prompt_embeddings)?; - let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?; - let (b, c, h, w) = src.dims4()?; - - // Run the transformer - let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; - let iou_token_out = hs.i((.., 0))?; - let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?; - - // Upscale mask embeddings and predict masks using the masks tokens. - let src = src.transpose(1, 2)?.reshape((b, c, h, w))?; - let upscaled_embedding = self - .output_upscaling_conv1 - .forward(&src)? - .apply(&self.output_upscaling_ln)? - .gelu()? - .apply(&self.output_upscaling_conv2)? - .gelu()?; - let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens); - for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() { - let h = mlp.forward(&mask_tokens_out.i((.., i))?)?; - hyper_in_list.push(h) - } - let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?; - let (b, c, h, w) = upscaled_embedding.dims4()?; - let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; - let masks = masks.reshape((b, (), h, w))?; - - // Generate mask quality predictions. 
- let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; - Ok((masks, iou_pred)) - } -} - -// Equivalent to torch.repeat_interleave -fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> { - let img = img.unsqueeze(dim + 1)?; - let mut dims = img.dims().to_vec(); - dims[dim + 1] = repeats; - img.broadcast_as(dims)?.flatten(dim, dim + 1) -} diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs deleted file mode 100644 index 7bbe8419..00000000 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ /dev/null @@ -1,239 +0,0 @@ -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::VarBuilder; - -#[derive(Debug)] -struct PostionEmbeddingRandom { - positional_encoding_gaussian_matrix: Tensor, -} - -impl PostionEmbeddingRandom { - fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> { - let positional_encoding_gaussian_matrix = - vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?; - Ok(Self { - positional_encoding_gaussian_matrix, - }) - } - - fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> { - let coords = coords.affine(2., -1.)?; - let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?; - let coords = (coords * (2. * std::f64::consts::PI))?; - Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) - } - - fn forward(&self, h: usize, w: usize) -> Result<Tensor> { - let device = self.positional_encoding_gaussian_matrix.device(); - let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; - let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; - let x_embed = (x_embed / w as f64)? - .reshape((1, ()))? - .broadcast_as((h, w))?; - let y_embed = (y_embed / h as f64)? - .reshape(((), 1))? - .broadcast_as((h, w))?; - let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; - self.pe_encoding(&coords)?.permute((2, 0, 1)) - } - - fn forward_with_coords( - &self, - coords_input: &Tensor, - image_size: (usize, usize), - ) -> Result<Tensor> { - let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?; - let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? 
/ image_size.0 as f64)?; - let c = coords_input.dim(D::Minus1)?; - let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?; - let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?; - self.pe_encoding(&coords) - } -} - -#[derive(Debug)] -pub struct PromptEncoder { - pe_layer: PostionEmbeddingRandom, - point_embeddings: Vec<candle_nn::Embedding>, - not_a_point_embed: candle_nn::Embedding, - mask_downscaling_conv1: candle_nn::Conv2d, - mask_downscaling_ln1: crate::LayerNorm2d, - mask_downscaling_conv2: candle_nn::Conv2d, - mask_downscaling_ln2: crate::LayerNorm2d, - mask_downscaling_conv3: candle_nn::Conv2d, - no_mask_embed: candle_nn::Embedding, - image_embedding_size: (usize, usize), - input_image_size: (usize, usize), - embed_dim: usize, - span: tracing::Span, -} - -impl PromptEncoder { - pub fn new( - embed_dim: usize, - image_embedding_size: (usize, usize), - input_image_size: (usize, usize), - mask_in_chans: usize, - vb: VarBuilder, - ) -> Result<Self> { - let num_points_embeddings = 4; - let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; - let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?; - let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?; - let cfg = candle_nn::Conv2dConfig { - stride: 2, - ..Default::default() - }; - let mask_downscaling_conv1 = - candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?; - let mask_downscaling_conv2 = candle_nn::conv2d( - mask_in_chans / 4, - mask_in_chans, - 2, - cfg, - vb.pp("mask_downscaling.3"), - )?; - let mask_downscaling_conv3 = candle_nn::conv2d( - mask_in_chans, - embed_dim, - 1, - Default::default(), - vb.pp("mask_downscaling.6"), - )?; - let mask_downscaling_ln1 = - crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; - let mask_downscaling_ln2 = - crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; - let mut point_embeddings = Vec::with_capacity(num_points_embeddings); - let vb_e = vb.pp("point_embeddings"); - for i in 0..num_points_embeddings { - let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?; - point_embeddings.push(emb) - } - let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder"); - Ok(Self { - pe_layer, - point_embeddings, - not_a_point_embed, - mask_downscaling_conv1, - mask_downscaling_ln1, - mask_downscaling_conv2, - mask_downscaling_ln2, - mask_downscaling_conv3, - no_mask_embed, - image_embedding_size, - input_image_size, - embed_dim, - span, - }) - } - - pub fn get_dense_pe(&self) -> Result<Tensor> { - self.pe_layer - .forward(self.image_embedding_size.0, self.image_embedding_size.1)? - .unsqueeze(0) - } - - fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> { - masks - .apply(&self.mask_downscaling_conv1)? - .apply(&self.mask_downscaling_ln1)? - .gelu()? - .apply(&self.mask_downscaling_conv2)? - .apply(&self.mask_downscaling_ln2)? - .gelu()? - .apply(&self.mask_downscaling_conv3) - } - - fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> { - let points = (points + 0.5)?; - let dev = points.device(); - let (points, labels) = if pad { - let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?; - let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? 
* (-1f64))?; - let points = Tensor::cat(&[&points, &padding_point], 1)?; - let labels = Tensor::cat(&[labels, &padding_label], 1)?; - (points, labels) - } else { - (points, labels.clone()) - }; - let point_embedding = self - .pe_layer - .forward_with_coords(&points, self.input_image_size)?; - let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; - let zeros = point_embedding.zeros_like()?; - let point_embedding = labels.lt(0f32)?.where_cond( - &self - .not_a_point_embed - .embeddings() - .broadcast_as(zeros.shape())?, - &point_embedding, - )?; - let labels0 = labels.eq(0f32)?.where_cond( - &self.point_embeddings[0] - .embeddings() - .broadcast_as(zeros.shape())?, - &zeros, - )?; - let point_embedding = (point_embedding + labels0)?; - let labels1 = labels.eq(1f32)?.where_cond( - &self.point_embeddings[1] - .embeddings() - .broadcast_as(zeros.shape())?, - &zeros, - )?; - let point_embedding = (point_embedding + labels1)?; - Ok(point_embedding) - } - - fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> { - let boxes = (boxes + 0.5)?; - let coords = boxes.reshape(((), 2, 2))?; - let corner_embedding = self - .pe_layer - .forward_with_coords(&coords, self.input_image_size)?; - let ce1 = corner_embedding.i((.., 0))?; - let ce2 = corner_embedding.i((.., 1))?; - let ce1 = (ce1 + self.point_embeddings[2].embeddings())?; - let ce2 = (ce2 + self.point_embeddings[3].embeddings())?; - Tensor::cat(&[&ce1, &ce2], 1) - } - - pub fn forward( - &self, - points: Option<(&Tensor, &Tensor)>, - boxes: Option<&Tensor>, - masks: Option<&Tensor>, - ) -> Result<(Tensor, Tensor)> { - let _enter = self.span.enter(); - let se_points = match points { - Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?), - None => None, - }; - let se_boxes = match boxes { - Some(boxes) => Some(self.embed_boxes(boxes)?), - None => None, - }; - let sparse_embeddings = match (se_points, se_boxes) { - (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?, - (Some(se_points), None) => se_points, - (None, Some(se_boxes)) => se_boxes, - (None, None) => { - Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? - } - }; - - let dense_embeddings = match masks { - None => { - let emb = self.no_mask_embed.embeddings(); - emb.reshape((1, (), 1, 1))?.expand(( - 1, - emb.elem_count(), - self.image_embedding_size.0, - self.image_embedding_size.1, - ))? 
- } - Some(masks) => self.embed_masks(masks)?, - }; - Ok((sparse_embeddings, dense_embeddings)) - } -} diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs deleted file mode 100644 index b1a81af6..00000000 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ /dev/null @@ -1,411 +0,0 @@ -use candle::{DType, IndexOp, Result, Tensor}; -use candle_nn::{Module, VarBuilder}; - -use crate::model_image_encoder::ImageEncoderViT; -use crate::model_mask_decoder::MaskDecoder; -use crate::model_prompt_encoder::PromptEncoder; -use crate::model_tiny_vit::{tiny_vit_5m, TinyViT}; - -const PROMPT_EMBED_DIM: usize = 256; -pub const IMAGE_SIZE: usize = 1024; -const VIT_PATCH_SIZE: usize = 16; -const PRED_IOU_THRESH: f32 = 0.88; -const STABILITY_SCORE_OFFSET: f32 = 1.0; -const STABILITY_SCORE_THRESHOLD: f32 = 0.95; -const MODEL_MASK_THRESHOLD: f32 = 0.0; -const CROP_NMS_THRESH: f32 = 0.7; - -#[derive(Debug)] -enum ImageEncoder { - Original(ImageEncoderViT), - TinyViT(TinyViT), -} - -impl Module for ImageEncoder { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - match self { - Self::Original(vit) => vit.forward(xs), - Self::TinyViT(vit) => vit.forward(xs), - } - } -} - -#[derive(Debug)] -pub struct Sam { - image_encoder: ImageEncoder, - prompt_encoder: PromptEncoder, - mask_decoder: MaskDecoder, - pixel_mean: Tensor, - pixel_std: Tensor, -} - -impl Sam { - pub fn new( - encoder_embed_dim: usize, - encoder_depth: usize, - encoder_num_heads: usize, - encoder_global_attn_indexes: &[usize], - vb: VarBuilder, - ) -> Result<Self> { - let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; - - let image_encoder = ImageEncoderViT::new( - IMAGE_SIZE, - VIT_PATCH_SIZE, - 3, - encoder_embed_dim, - encoder_depth, - encoder_num_heads, - PROMPT_EMBED_DIM, - /* qkv_bias */ true, - /* use_rel_pos */ true, - /* use_abs_pos */ true, - /* window_size */ 14, - /* global_attn_indexes */ encoder_global_attn_indexes, - vb.pp("image_encoder"), - )?; - let prompt_encoder = PromptEncoder::new( - PROMPT_EMBED_DIM, - (image_embedding_size, image_embedding_size), - (IMAGE_SIZE, IMAGE_SIZE), - 16, - vb.pp("prompt_encoder"), - )?; - let mask_decoder = MaskDecoder::new( - PROMPT_EMBED_DIM, - /* num_multitask_outputs */ 3, - /* iou_head_depth */ 3, - /* iou_head_hidden_dim */ 256, - vb.pp("mask_decoder"), - )?; - let pixel_mean = - Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; - let pixel_std = - Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; - Ok(Self { - image_encoder: ImageEncoder::Original(image_encoder), - prompt_encoder, - mask_decoder, - pixel_std, - pixel_mean, - }) - } - - pub fn new_tiny(vb: VarBuilder) -> Result<Self> { - let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; - - let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?; - let prompt_encoder = PromptEncoder::new( - PROMPT_EMBED_DIM, - (image_embedding_size, image_embedding_size), - (IMAGE_SIZE, IMAGE_SIZE), - 16, - vb.pp("prompt_encoder"), - )?; - let mask_decoder = MaskDecoder::new( - PROMPT_EMBED_DIM, - /* num_multitask_outputs */ 3, - /* iou_head_depth */ 3, - /* iou_head_hidden_dim */ 256, - vb.pp("mask_decoder"), - )?; - let pixel_mean = - Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; - let pixel_std = - Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; - Ok(Self { - image_encoder: ImageEncoder::TinyViT(image_encoder), - prompt_encoder, - mask_decoder, - 
pixel_std, - pixel_mean, - }) - } - - pub fn forward( - &self, - img: &Tensor, - point: Option<(f64, f64)>, - multimask_output: bool, - ) -> Result<(Tensor, Tensor)> { - let (_c, original_h, original_w) = img.dims3()?; - let img = self.preprocess(img)?.unsqueeze(0)?; - let img_embeddings = self.image_encoder.forward(&img)?; - let image_pe = self.prompt_encoder.get_dense_pe()?; - let points = match point { - None => None, - Some((x, y)) => { - let points = Tensor::new( - &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], - img.device(), - )?; - let labels = Tensor::ones((1, 1), DType::F32, img.device())?; - Some((points, labels)) - } - }; - let points = points.as_ref().map(|(x, y)| (x, y)); - let (sparse_prompt_embeddings, dense_prompt_embeddings) = - self.prompt_encoder.forward(points, None, None)?; - let (low_res_mask, iou_predictions) = self.mask_decoder.forward( - &img_embeddings, - &image_pe, - &sparse_prompt_embeddings, - &dense_prompt_embeddings, - multimask_output, - )?; - let mask = low_res_mask - .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? - .get(0)? - .i((.., ..original_h, ..original_w))?; - Ok((mask, iou_predictions)) - } - - pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { - let img = img - .broadcast_mul(&self.pixel_std)? - .broadcast_add(&self.pixel_mean)?; - img.maximum(&img.zeros_like()?)? - .minimum(&(img.ones_like()? * 255.)?) - } - - pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { - let (_c, h, w) = img.dims3()?; - let img = img - .to_dtype(DType::F32)? - .broadcast_sub(&self.pixel_mean)? - .broadcast_div(&self.pixel_std)?; - if h > IMAGE_SIZE || w > IMAGE_SIZE { - candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") - } - let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; - img.pad_with_zeros(2, 0, IMAGE_SIZE - w) - } - - fn process_crop( - &self, - img: &Tensor, - cb: CropBox, - point_grids: &[(f64, f64)], - ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { - // Crop the image and calculate embeddings. - let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; - let img = self.preprocess(&img)?.unsqueeze(0)?; - let img_embeddings = self.image_encoder.forward(&img)?; - - let crop_w = cb.x1 - cb.x0; - let crop_h = cb.y1 - cb.y0; - - // Generate masks for this crop. - let image_pe = self.prompt_encoder.get_dense_pe()?; - let points = point_grids - .iter() - .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) - .collect::<Vec<_>>(); - - let mut bboxes = Vec::new(); - for points in points.chunks(64) { - // Run the model on this batch. - let points_len = points.len(); - let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; - let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; - let (sparse_prompt_embeddings, dense_prompt_embeddings) = - self.prompt_encoder - .forward(Some((&in_points, &in_labels)), None, None)?; - - let (low_res_mask, iou_predictions) = self.mask_decoder.forward( - &img_embeddings, - &image_pe, - &sparse_prompt_embeddings, - &dense_prompt_embeddings, - /* multimask_output */ true, - )?; - let low_res_mask = low_res_mask.flatten(0, 1)?; - let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; - let dev = low_res_mask.device(); - - for (i, iou) in iou_predictions.iter().enumerate() { - // Filter by predicted IoU. - if *iou < PRED_IOU_THRESH { - continue; - } - let low_res_mask = low_res_mask.get(i)?; - - // Calculate stability score. 
- let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? - .broadcast_as(low_res_mask.shape())?; - let intersections = low_res_mask - .ge(&bound)? - .to_dtype(DType::F32)? - .sum_all()? - .to_vec0::<f32>()?; - let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? - .broadcast_as(low_res_mask.shape())?; - let unions = low_res_mask - .ge(&bound)? - .to_dtype(DType::F32)? - .sum_all()? - .to_vec0::<f32>()?; - let stability_score = intersections / unions; - if stability_score < STABILITY_SCORE_THRESHOLD { - continue; - } - - // Threshold masks and calculate boxes. - let low_res_mask = low_res_mask - .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? - .to_dtype(DType::U32)?; - let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; - let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; - let min_max_x = min_max_indexes(&low_res_mask_per_x); - let min_max_y = min_max_indexes(&low_res_mask_per_y); - if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { - let bbox = candle_examples::object_detection::Bbox { - xmin: x0 as f32, - ymin: y0 as f32, - xmax: x1 as f32, - ymax: y1 as f32, - confidence: *iou, - data: low_res_mask, - }; - bboxes.push(bbox); - } - // TODO: - // Filter boxes that touch crop boundaries - // Compress to RLE. - } - } - - let mut bboxes = vec![bboxes]; - // Remove duplicates within this crop. - candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); - - // TODO: Return to the original image frame. - Ok(bboxes.remove(0)) - } - - pub fn generate_masks( - &self, - img: &Tensor, - points_per_side: usize, - crop_n_layer: usize, - crop_overlap_ratio: f64, - crop_n_points_downscale_factor: usize, - ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { - let (_c, h, w) = img.dims3()?; - let point_grids = build_all_layer_point_grids( - points_per_side, - crop_n_layer, - crop_n_points_downscale_factor, - ); - let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); - let mut bboxes = Vec::new(); - for crop_box in crop_boxes.into_iter() { - let layer_idx = crop_box.layer_idx; - let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; - bboxes.extend(b) - } - // TODO: remove duplicates - Ok(bboxes) - } -} - -// Return the first and last indexes i for which values[i] > 0 -fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { - let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); - for (i, &s) in values.iter().enumerate() { - if s == 0 { - continue; - } - min_i = usize::min(i, min_i); - max_i = usize::max(i, max_i); - } - if max_i < min_i { - None - } else { - Some((min_i, max_i)) - } -} - -#[derive(Debug)] -struct CropBox { - x0: usize, - y0: usize, - x1: usize, - y1: usize, - layer_idx: usize, -} - -impl CropBox { - fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self { - Self { - x0, - y0, - x1, - y1, - layer_idx, - } - } -} - -fn generate_crop_boxes( - (im_h, im_w): (usize, usize), - n_layers: usize, - overlap_ratio: f64, -) -> Vec<CropBox> { - fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize { - f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize - } - - let short_side = usize::min(im_h, im_w); - - let mut crop_boxes = Vec::new(); - - // Original image. 
- crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0)); - - for layer_idx in 1..=n_layers { - let n_crops_per_side = 1 << layer_idx; - let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize; - let crop_w = crop_len(im_w, n_crops_per_side, overlap); - let crop_h = crop_len(im_w, n_crops_per_side, overlap); - - for i_x in 0..n_crops_per_side { - let x0 = (crop_w - overlap) * i_x; - for i_y in 0..n_crops_per_side { - let y0 = (crop_h - overlap) * i_y; - let x1 = usize::min(im_w, x0 + crop_w); - let y1 = usize::min(im_h, y0 + crop_h); - crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx)); - } - } - } - - crop_boxes -} - -// Generates a 2D grid of points evenly spaced in [0,1]x[0,1]. -fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> { - let offset = 1f64 / (2 * n_per_side) as f64; - let mut points = Vec::with_capacity(n_per_side * n_per_side); - for i_x in 0..n_per_side { - let x = offset + i_x as f64 / n_per_side as f64; - for i_y in 0..n_per_side { - let y = offset + i_y as f64 / n_per_side as f64; - points.push((x, y)) - } - } - points -} - -fn build_all_layer_point_grids( - n_per_side: usize, - n_layers: usize, - scale_per_layer: usize, -) -> Vec<Vec<(f64, f64)>> { - let mut points_by_layer = Vec::with_capacity(n_layers + 1); - for i in 0..=n_layers { - let n_points = n_per_side / scale_per_layer.pow(i as u32); - points_by_layer.push(build_point_grid(n_points)) - } - points_by_layer -} diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-examples/examples/segment-anything/model_tiny_vit.rs deleted file mode 100644 index ff076773..00000000 --- a/candle-examples/examples/segment-anything/model_tiny_vit.rs +++ /dev/null @@ -1,633 +0,0 @@ -// Adapted from: -// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py -use candle::{IndexOp, Result, Tensor, D}; -use candle_nn::{Conv2dConfig, Module, VarBuilder}; - -const MBCONV_EXPAND_RATIO: usize = 4; -const MLP_RATIO: usize = 4; -const LOCAL_CONV_SIZE: usize = 3; -const IMG_SIZE: usize = 1024; -const IN_CHANNELS: usize = 3; - -#[derive(Debug)] -struct Conv2dBN { - c: candle_nn::Conv2d, - bn: candle_nn::BatchNorm, - span: tracing::Span, -} - -impl Conv2dBN { - fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> { - let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?; - let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?; - let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn"); - Ok(Self { c, bn, span }) - } -} - -impl Module for Conv2dBN { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.c)?.apply(&self.bn) - } -} - -#[derive(Debug)] -struct PatchEmbed { - conv1: Conv2dBN, - conv2: Conv2dBN, - span: tracing::Span, -} - -impl PatchEmbed { - fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { - let cfg = candle_nn::Conv2dConfig { - stride: 2, - padding: 1, - ..Default::default() - }; - let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?; - let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?; - let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); - Ok(Self { conv1, conv2, span }) - } -} - -impl Module for PatchEmbed { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2) - } -} - -#[derive(Debug)] -struct MBConv { - conv1: Conv2dBN, - conv2: Conv2dBN, - 
conv3: Conv2dBN, - span: tracing::Span, -} - -impl MBConv { - fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> { - let hidden = in_ * expand_ratio; - let cfg2 = candle_nn::Conv2dConfig { - padding: 1, - groups: hidden, - ..Default::default() - }; - let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?; - let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?; - let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?; - let span = tracing::span!(tracing::Level::TRACE, "mb-conv"); - Ok(Self { - conv1, - conv2, - conv3, - span, - }) - } -} - -impl Module for MBConv { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let shortcut = xs; - let xs = xs - .apply(&self.conv1)? - .gelu()? - .apply(&self.conv2)? - .gelu()? - .apply(&self.conv3)?; - (xs + shortcut)?.gelu() - } -} - -#[derive(Debug)] -struct PatchMerging { - conv1: Conv2dBN, - conv2: Conv2dBN, - conv3: Conv2dBN, - input_resolution: (usize, usize), - span: tracing::Span, -} - -impl PatchMerging { - fn new( - input_resolution: (usize, usize), - dim: usize, - out: usize, - vb: VarBuilder, - ) -> Result<Self> { - let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 }; - let cfg2 = candle_nn::Conv2dConfig { - padding: 1, - stride, - groups: out, - ..Default::default() - }; - let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?; - let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?; - let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?; - let span = tracing::span!(tracing::Level::TRACE, "patch-merging"); - Ok(Self { - conv1, - conv2, - conv3, - input_resolution, - span, - }) - } -} - -impl Module for PatchMerging { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let xs = if xs.rank() == 3 { - let (h, w) = self.input_resolution; - let b = xs.dim(0)?; - xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))? - } else { - xs.clone() - }; - xs.apply(&self.conv1)? - .gelu()? - .apply(&self.conv2)? - .gelu()? - .apply(&self.conv3)? - .flatten_from(2)? - .transpose(1, 2) - } -} - -#[derive(Debug)] -struct ConvLayer { - blocks: Vec<MBConv>, - downsample: Option<PatchMerging>, - span: tracing::Span, -} - -impl ConvLayer { - fn new( - dim: usize, - out: usize, - input_resolution: (usize, usize), - depth: usize, - downsample: bool, - conv_expand_ratio: usize, - vb: VarBuilder, - ) -> Result<Self> { - let vb_b = vb.pp("blocks"); - let mut blocks = Vec::with_capacity(depth); - for index in 0..depth { - let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?; - blocks.push(block) - } - let downsample = if downsample { - let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; - Some(downsample) - } else { - None - }; - let span = tracing::span!(tracing::Level::TRACE, "conv-layer"); - Ok(Self { - blocks, - downsample, - span, - }) - } -} - -impl Module for ConvLayer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let mut xs = xs.clone(); - for block in self.blocks.iter() { - xs = block.forward(&xs)? 
- } - match &self.downsample { - None => Ok(xs), - Some(downsample) => downsample.forward(&xs), - } - } -} - -#[derive(Debug)] -struct Mlp { - norm: candle_nn::LayerNorm, - fc1: crate::Linear, - fc2: crate::Linear, - span: tracing::Span, -} - -impl Mlp { - fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> { - let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?; - let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?; - let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?; - let span = tracing::span!(tracing::Level::TRACE, "mlp"); - Ok(Self { - norm, - fc1, - fc2, - span, - }) - } -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.norm)? - .apply(&self.fc1)? - .gelu()? - .apply(&self.fc2) - } -} - -#[derive(Debug)] -struct Attention { - norm: candle_nn::LayerNorm, - qkv: crate::Linear, - proj: crate::Linear, - ab: Tensor, - key_dim: usize, - num_heads: usize, - d: usize, - dh: usize, - scale: f64, - span: tracing::Span, - span_matmul: tracing::Span, - span_softmax: tracing::Span, -} - -impl Attention { - fn new( - dim: usize, - key_dim: usize, - num_heads: usize, - attn_ratio: usize, - resolution: (usize, usize), - vb: VarBuilder, - ) -> Result<Self> { - let d = attn_ratio * key_dim; - let dh = d * num_heads; - let nh_kd = key_dim * num_heads; - let h = dh + nh_kd * 2; - let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?; - let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?; - let proj = crate::linear(vb.pp("proj"), dh, dim, true)?; - - let points = (0..resolution.0) - .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64))) - .collect::<Vec<_>>(); - let mut idxs = Vec::with_capacity(points.len() * points.len()); - let mut attention_offsets = std::collections::HashMap::new(); - for &(x1, y1) in points.iter() { - for &(x2, y2) in points.iter() { - let offset = ((x2 - x1).abs(), (y2 - y1).abs()); - let l = attention_offsets.len(); - let idx = attention_offsets.entry(offset).or_insert(l); - idxs.push(*idx as u32) - } - } - let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; - let idxs = Tensor::new(idxs, attention_biases.device())?; - let ab = - attention_biases - .index_select(&idxs, 1)? - .reshape(((), points.len(), points.len()))?; - let span = tracing::span!(tracing::Level::TRACE, "attention"); - let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); - let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); - Ok(Self { - norm, - qkv, - proj, - ab, - key_dim, - num_heads, - d, - dh, - scale: 1f64 / (key_dim as f64).sqrt(), - span, - span_matmul, - span_softmax, - }) - } -} - -impl Module for Attention { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let (b, n, _) = xs.dims3()?; - let xs = xs.apply(&self.norm)?; - let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; - let q = qkv - .narrow(D::Minus1, 0, self.key_dim)? - .permute((0, 2, 1, 3))? - .contiguous()?; - let k = qkv - .narrow(D::Minus1, self.key_dim, self.key_dim)? - .permute((0, 2, 1, 3))? - .contiguous()?; - let v = qkv - .narrow(D::Minus1, 2 * self.key_dim, self.d)? - .permute((0, 2, 1, 3))? - .contiguous()?; - let attn = { - let _enter = self.span_matmul.enter(); - (q.matmul(&k.t()?)? * self.scale)? - }; - let attn = attn.broadcast_add(&self.ab)?; - let attn = { - let _enter = self.span_softmax.enter(); - candle_nn::ops::softmax_last_dim(&attn)? 
- }; - let attn = { - let _enter = self.span_matmul.enter(); - attn.matmul(&v)? - }; - attn.transpose(1, 2)? - .reshape((b, n, self.dh))? - .apply(&self.proj) - } -} - -#[derive(Debug)] -struct TinyViTBlock { - attn: Attention, - local_conv: Conv2dBN, - mlp: Mlp, - window_size: usize, - input_resolution: (usize, usize), - span: tracing::Span, -} - -impl TinyViTBlock { - fn new( - dim: usize, - input_resolution: (usize, usize), - num_heads: usize, - window_size: usize, - vb: VarBuilder, - ) -> Result<Self> { - let head_dim = dim / num_heads; - let attn = Attention::new( - dim, - head_dim, - num_heads, - 1, - (window_size, window_size), - vb.pp("attn"), - )?; - let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; - let cfg = candle_nn::Conv2dConfig { - padding: LOCAL_CONV_SIZE / 2, - groups: dim, - ..Default::default() - }; - let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; - let span = tracing::span!(tracing::Level::TRACE, "attention"); - Ok(Self { - attn, - local_conv, - mlp, - window_size, - input_resolution, - span, - }) - } -} - -impl Module for TinyViTBlock { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let (h, w) = self.input_resolution; - let (b, l, c) = xs.dims3()?; - let res_x = xs; - let xs = if h == self.window_size && w == self.window_size { - self.attn.forward(xs)? - } else { - let xs = xs.reshape((b, h, w, c))?; - let pad_b = (self.window_size - h % self.window_size) % self.window_size; - let pad_r = (self.window_size - w % self.window_size) % self.window_size; - - let xs = if pad_b > 0 { - xs.pad_with_zeros(1, 0, pad_b)? - } else { - xs - }; - let xs = if pad_r > 0 { - xs.pad_with_zeros(2, 0, pad_r)? - } else { - xs - }; - let (p_h, p_w) = (h + pad_b, w + pad_r); - let n_h = p_h / self.window_size; - let n_w = p_w / self.window_size; - let xs = xs - .reshape((b, n_h, self.window_size, n_w, self.window_size, c))? - .transpose(2, 3)? - .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?; - let xs = self.attn.forward(&xs)?; - let xs = xs - .reshape((b, n_h, n_w, self.window_size, self.window_size, c))? - .transpose(2, 3)? - .reshape((b, p_h, p_w, c))?; - let xs = if pad_r > 0 { - xs.i((.., .., ..w))?.contiguous()? - } else { - xs - }; - let xs = if pad_b > 0 { - xs.i((.., ..h, ..))?.contiguous()? - } else { - xs - }; - xs.reshape((b, l, c))? - }; - let xs = (xs + res_x)?; - let xs = xs - .transpose(1, 2)? - .reshape((b, c, h, w))? - .apply(&self.local_conv)? - .reshape((b, c, l))? - .transpose(1, 2)?; - &xs + self.mlp.forward(&xs)? 
- } -} - -#[derive(Debug)] -struct BasicLayer { - blocks: Vec<TinyViTBlock>, - downsample: Option<PatchMerging>, - span: tracing::Span, -} - -impl BasicLayer { - #[allow(clippy::too_many_arguments)] - fn new( - dim: usize, - input_resolution: (usize, usize), - depth: usize, - num_heads: usize, - window_size: usize, - downsample: bool, - out: usize, - vb: VarBuilder, - ) -> Result<Self> { - let vb_b = vb.pp("blocks"); - let mut blocks = Vec::with_capacity(depth); - for index in 0..depth { - let block = TinyViTBlock::new( - dim, - input_resolution, - num_heads, - window_size, - vb_b.pp(index), - )?; - blocks.push(block) - } - let downsample = if downsample { - let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; - Some(downsample) - } else { - None - }; - let span = tracing::span!(tracing::Level::TRACE, "basic-layer"); - Ok(Self { - blocks, - downsample, - span, - }) - } -} - -impl Module for BasicLayer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let mut xs = xs.clone(); - for block in self.blocks.iter() { - xs = block.forward(&xs)? - } - match &self.downsample { - None => Ok(xs), - Some(downsample) => downsample.forward(&xs), - } - } -} - -#[derive(Debug)] -pub struct TinyViT { - patch_embed: PatchEmbed, - layer0: ConvLayer, - layers: Vec<BasicLayer>, - // norm_head: candle_nn::LayerNorm, - // head: candle_nn::Linear, - neck_conv1: candle_nn::Conv2d, - neck_ln1: crate::LayerNorm2d, - neck_conv2: candle_nn::Conv2d, - neck_ln2: crate::LayerNorm2d, - span: tracing::Span, - span_neck: tracing::Span, -} - -impl TinyViT { - pub fn new( - embed_dims: &[usize], - depths: &[usize], - num_heads: &[usize], - window_sizes: &[usize], - _num_classes: usize, - vb: VarBuilder, - ) -> Result<Self> { - let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; - let patches_resolution = IMG_SIZE / 4; - - let vb_l = vb.pp("layers"); - let layer0 = ConvLayer::new( - /* dim */ embed_dims[0], - /* out */ embed_dims[1], - /* input_resolution */ (patches_resolution, patches_resolution), - /* depth */ depths[0], - /* downsample */ true, - /* conv_expand_ratio */ MBCONV_EXPAND_RATIO, - vb_l.pp(0), - )?; - - let num_layers = embed_dims.len(); - let mut layers = Vec::with_capacity(num_layers - 1); - for i_layer in 1..num_layers { - let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2)); - let layer = BasicLayer::new( - /* dim */ embed_dims[i_layer], - /* input_resolution */ (patches_resolution, patches_resolution), - /* depth */ depths[i_layer], - /* num_heads */ num_heads[i_layer], - /* window_size */ window_sizes[i_layer], - /* downsample */ i_layer < num_layers - 1, - /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)], - vb_l.pp(i_layer), - )?; - layers.push(layer) - } - - let last_embed_dim = embed_dims[embed_dims.len() - 1]; - // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; - // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; - let neck_conv1 = - candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; - let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; - let cfg = candle_nn::Conv2dConfig { - padding: 1, - ..Default::default() - }; - let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; - - let span = tracing::span!(tracing::Level::TRACE, "tiny-vit"); - let span_neck = 
tracing::span!(tracing::Level::TRACE, "neck"); - Ok(Self { - patch_embed, - layer0, - layers, - neck_conv1, - neck_ln1, - neck_conv2, - neck_ln2, - span, - span_neck, - }) - } -} - -impl Module for TinyViT { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let xs = self.patch_embed.forward(xs)?; - let mut xs = self.layer0.forward(&xs)?; - for layer in self.layers.iter() { - xs = layer.forward(&xs)? - } - let (b, _, c) = xs.dims3()?; - let _enter = self.span_neck.enter(); - xs.reshape((b, 64, 64, c))? - .permute((0, 3, 1, 2))? - .apply(&self.neck_conv1)? - .apply(&self.neck_ln1)? - .apply(&self.neck_conv2)? - .apply(&self.neck_ln2) - } -} - -pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> { - TinyViT::new( - /* embed_dims */ &[64, 128, 160, 320], - /* depths */ &[2, 2, 6, 2], - /* num_heads */ &[2, 4, 5, 10], - /* window_sizes */ &[7, 7, 14, 7], - /* num_classes */ 1000, - vb, - ) -} diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs deleted file mode 100644 index e12aac08..00000000 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ /dev/null @@ -1,221 +0,0 @@ -use candle::{Result, Tensor}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; - -#[derive(Debug)] -struct Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - out_proj: Linear, - num_heads: usize, -} - -impl Attention { - fn new( - embedding_dim: usize, - num_heads: usize, - downsample_rate: usize, - vb: VarBuilder, - ) -> Result<Self> { - let internal_dim = embedding_dim / downsample_rate; - let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?; - let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?; - let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?; - let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?; - Ok(Self { - q_proj, - k_proj, - v_proj, - out_proj, - num_heads, - }) - } - - fn separate_heads(&self, x: &Tensor) -> Result<Tensor> { - let (b, n, c) = x.dims3()?; - x.reshape((b, n, self.num_heads, c / self.num_heads))? - .transpose(1, 2)? - .contiguous() - } - - fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> { - let (b, n_heads, n_tokens, c_per_head) = x.dims4()?; - x.transpose(1, 2)? - .reshape((b, n_tokens, n_heads * c_per_head)) - } - - fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { - let q = self.q_proj.forward(&q.contiguous()?)?; - let k = self.k_proj.forward(&k.contiguous()?)?; - let v = self.v_proj.forward(&v.contiguous()?)?; - - let q = self.separate_heads(&q)?; - let k = self.separate_heads(&k)?; - let v = self.separate_heads(&v)?; - - let (_, _, _, c_per_head) = q.dims4()?; - let attn = (q.matmul(&k.t()?)? 
/ (c_per_head as f64).sqrt())?; - let attn = candle_nn::ops::softmax_last_dim(&attn)?; - - let out = attn.matmul(&v)?; - self.recombine_heads(&out)?.apply(&self.out_proj) - } -} - -#[derive(Debug)] -struct TwoWayAttentionBlock { - self_attn: Attention, - norm1: LayerNorm, - cross_attn_token_to_image: Attention, - norm2: LayerNorm, - mlp: crate::MlpBlock, - norm3: LayerNorm, - norm4: LayerNorm, - cross_attn_image_to_token: Attention, - skip_first_layer_pe: bool, -} - -impl TwoWayAttentionBlock { - fn new( - embedding_dim: usize, - num_heads: usize, - mlp_dim: usize, - skip_first_layer_pe: bool, - vb: VarBuilder, - ) -> Result<Self> { - let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?; - let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?; - let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?; - let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?; - let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?; - let cross_attn_token_to_image = Attention::new( - embedding_dim, - num_heads, - 2, - vb.pp("cross_attn_token_to_image"), - )?; - let cross_attn_image_to_token = Attention::new( - embedding_dim, - num_heads, - 2, - vb.pp("cross_attn_image_to_token"), - )?; - let mlp = crate::MlpBlock::new( - embedding_dim, - mlp_dim, - candle_nn::Activation::Relu, - vb.pp("mlp"), - )?; - Ok(Self { - self_attn, - norm1, - cross_attn_image_to_token, - norm2, - mlp, - norm3, - norm4, - cross_attn_token_to_image, - skip_first_layer_pe, - }) - } - - fn forward( - &self, - queries: &Tensor, - keys: &Tensor, - query_pe: &Tensor, - key_pe: &Tensor, - ) -> Result<(Tensor, Tensor)> { - // Self attention block - let queries = if self.skip_first_layer_pe { - self.self_attn.forward(queries, queries, queries)? - } else { - let q = (queries + query_pe)?; - let attn_out = self.self_attn.forward(&q, &q, queries)?; - (queries + attn_out)? 
- }; - let queries = self.norm1.forward(&queries)?; - - // Cross attention block, tokens attending to image embedding - let q = (&queries + query_pe)?; - let k = (keys + key_pe)?; - let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?; - let queries = (&queries + attn_out)?; - let queries = self.norm2.forward(&queries)?; - - // MLP block - let mlp_out = self.mlp.forward(&queries); - let queries = (queries + mlp_out)?; - let queries = self.norm3.forward(&queries)?; - - // Cross attention block, image embedding attending to tokens - let q = (&queries + query_pe)?; - let k = (keys + key_pe)?; - let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?; - let keys = (keys + attn_out)?; - let keys = self.norm4.forward(&keys)?; - - Ok((queries, keys)) - } -} - -#[derive(Debug)] -pub struct TwoWayTransformer { - layers: Vec<TwoWayAttentionBlock>, - final_attn_token_to_image: Attention, - norm_final_attn: LayerNorm, -} - -impl TwoWayTransformer { - pub fn new( - depth: usize, - embedding_dim: usize, - num_heads: usize, - mlp_dim: usize, - vb: VarBuilder, - ) -> Result<Self> { - let vb_l = vb.pp("layers"); - let mut layers = Vec::with_capacity(depth); - for i in 0..depth { - let layer = - TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?; - layers.push(layer) - } - let final_attn_token_to_image = Attention::new( - embedding_dim, - num_heads, - 2, - vb.pp("final_attn_token_to_image"), - )?; - let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?; - Ok(Self { - layers, - final_attn_token_to_image, - norm_final_attn, - }) - } - - pub fn forward( - &self, - image_embedding: &Tensor, - image_pe: &Tensor, - point_embedding: &Tensor, - ) -> Result<(Tensor, Tensor)> { - let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?; - let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?; - - let mut queries = point_embedding.clone(); - let mut keys = image_embedding; - - for layer in self.layers.iter() { - (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)? - } - - let q = (&queries + point_embedding)?; - let k = (&keys + image_pe)?; - let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?; - let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?; - - Ok((queries, keys)) - } -} diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index 20021b45..ecf75bdf 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_examples::object_detection::{non_maximum_suppression, Bbox}; +use candle_transformers::object_detection::{non_maximum_suppression, Bbox}; mod darknet; use anyhow::Result; diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 2017b5be..d48bac35 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -8,8 +8,8 @@ mod model; use model::{Multiples, YoloV8, YoloV8Pose}; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use candle_nn::{Module, VarBuilder}; +use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; use image::DynamicImage; |
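The object_detection helpers that the yolo imports above now take from candle_transformers are the same ones exercised by the segment-anything code earlier in this diff: a Bbox with xmin/ymin/xmax/ymax/confidence fields and a generic data payload, plus an in-place non_maximum_suppression over per-class box lists. Below is a minimal usage sketch, assuming those names and signatures are unchanged after the move (illustrative only, not part of this commit):

use candle_transformers::object_detection::{non_maximum_suppression, Bbox};

fn main() {
    // Two heavily overlapping detections for one class; `data` can carry any per-box payload.
    let mut bboxes: Vec<Vec<Bbox<()>>> = vec![vec![
        Bbox { xmin: 0.0, ymin: 0.0, xmax: 10.0, ymax: 10.0, confidence: 0.9, data: () },
        Bbox { xmin: 0.5, ymin: 0.5, xmax: 10.5, ymax: 10.5, confidence: 0.3, data: () },
    ]];
    // Boxes whose IoU with a higher-confidence box exceeds the threshold are dropped in place,
    // mirroring how the SAM crop post-processing above applies CROP_NMS_THRESH.
    non_maximum_suppression(&mut bboxes, 0.7);
    assert_eq!(bboxes[0].len(), 1);
}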