diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-10 10:20:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-10 10:20:18 +0100 |
commit | 35f72514f59b3fa4bd321e3e88a75f5b43cf060f (patch) | |
tree | 37dd25098bcf16293744758268a0486337d18431 | |
parent | d3f05eae8c4f2df186b46e433be101ac39fceca5 (diff) | |
download | candle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.tar.gz candle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.tar.bz2 candle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.zip |
Move more models to candle-transformers (#796)
* Move dinov2.
* Move efficientnet.
* Move the quantized llama model.
* Move segment-anything.
21 files changed, 773 insertions, 759 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index e80c81e2..d3adb37c 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -9,285 +9,10 @@ extern crate accelerate_src; use clap::Parser; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2; -const IMG_SIZE: usize = 518; -const PATCH_SIZE: usize = 14; -const NUM_CLASSES: usize = 1000; - -fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - if bias { - candle_nn::linear(in_dim, out_dim, vb) - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb) - } -} - -#[derive(Debug)] -struct Attention { - qkv: Linear, - proj: Linear, - num_heads: usize, - scale: f64, -} - -impl Attention { - fn new( - vb: VarBuilder, - dim: usize, - num_heads: usize, - qkv_bias: bool, - proj_bias: bool, - ) -> Result<Self> { - let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; - let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; - let scale = 1. / ((dim / num_heads) as f64).sqrt(); - Ok(Self { - qkv, - proj, - num_heads, - scale, - }) - } -} - -impl Module for Attention { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (b, n, c) = xs.dims3()?; - let qkv = self - .qkv - .forward(xs)? - .reshape((b, n, 3, self.num_heads, c / self.num_heads))? - .transpose(1, 2)? // 02134 - .transpose(0, 1)? // 20134 - .transpose(2, 3)?; // 20314 - let q = (qkv.i(0)? * self.scale)?; - let k = qkv.i(1)?; - let v = qkv.i(2)?; - let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; - let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; - self.proj.forward(&attn) - } -} - -#[derive(Debug)] -struct LayerScale { - gamma: Tensor, -} - -impl LayerScale { - fn new(vb: VarBuilder, dim: usize) -> Result<Self> { - let gamma = vb.get(dim, "gamma")?; - Ok(Self { gamma }) - } -} - -impl Module for LayerScale { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs.broadcast_mul(&self.gamma) - } -} - -#[derive(Debug)] -struct Mlp { - fc1: Linear, - fc2: Linear, -} - -impl Mlp { - fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { - let out_features = in_features; - let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; - let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.fc1.forward(xs)?.gelu()?; - self.fc2.forward(&xs) - } -} - -#[derive(Debug)] -struct Block { - norm1: LayerNorm, - attn: Attention, - ls1: LayerScale, - norm2: LayerNorm, - mlp: Mlp, - ls2: LayerScale, -} - -impl Block { - fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { - let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; - let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; - let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; - let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; - let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; - Ok(Self { - norm1, - attn, - ls1, - norm2, - mlp, - ls2, - }) - } -} - -impl Module for Block { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - let xs = self - .ls1 - .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = self - .ls2 - .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; - xs + residual - } -} - -#[derive(Debug)] -struct PatchEmbed { - proj: candle_nn::Conv2d, - patch_size: (usize, usize), - num_patches: usize, -} - -impl PatchEmbed { - fn new( - vb: VarBuilder, - img_size: usize, - patch_size: usize, - in_chans: usize, - embed_dim: usize, - ) -> Result<Self> { - let config = candle_nn::Conv2dConfig { - stride: patch_size, - ..Default::default() - }; - let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; - let num_patches = (img_size / patch_size) * (img_size / patch_size); - Ok(Self { - proj, - patch_size: (patch_size, patch_size), - num_patches, - }) - } -} - -impl Module for PatchEmbed { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _c, h, w) = xs.dims4()?; - let (patch_h, patch_w) = self.patch_size; - if (h % patch_h) != 0 { - candle::bail!("image height {h} is not a multiple of patch height {patch_h}") - } - if (w % patch_w) != 0 { - candle::bail!("image width {w} is not a multiple of patch width {patch_w}") - } - let xs = self.proj.forward(xs)?; - let (b, c, h, w) = xs.dims4()?; - // flatten embeddings. - xs.reshape((b, c, h * w))?.transpose(1, 2) - } -} - -#[derive(Debug)] -pub struct DinoVisionTransformer { - patch_embed: PatchEmbed, - cls_token: Tensor, - pos_embed: Tensor, - blocks: Vec<Block>, - norm: LayerNorm, - head: Linear, -} - -impl DinoVisionTransformer { - pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { - let patch_embed = - PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; - let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; - let num_tokens = 1; - let pos_embed = vb.get( - (1, patch_embed.num_patches + num_tokens, embed_dim), - "pos_embed", - )?; - let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; - let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; - let vb_b = vb.pp("blocks"); - let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) - .collect::<Result<Vec<_>>>()?; - Ok(Self { - patch_embed, - cls_token, - pos_embed, - blocks, - norm, - head, - }) - } - - fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { - let npatch = xs.dim(1)? - 1; - let n = self.pos_embed.dim(1)? - 1; - let sqrt_n = (n as f64).sqrt(); - if npatch == n && w == h { - return Ok(xs.clone()); - } - let class_pos_embed = self.pos_embed.i((.., ..1))?; - let patch_pos_embed = self.pos_embed.i((.., 1..))?; - let dim = xs.dim(D::Minus1)?; - let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); - let patch_pos_embed = patch_pos_embed - .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? - .transpose(2, 3)? - .transpose(1, 2)?; - // This uses bicubic interpolation in the original implementation. - let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; - let el_count = patch_pos_embed.shape().elem_count(); - let patch_pos_embed = - patch_pos_embed - .transpose(1, 2)? - .transpose(2, 3)? - .reshape((1, el_count / dim, dim))?; - Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) - } - - fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { - let (_b, _nc, w, h) = xs.dims4()?; - let xs = self.patch_embed.forward(xs)?; - let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; - &xs + &self.interpolate_pos_encoding(&xs, w, h)? - } -} - -impl Module for DinoVisionTransformer { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.prepare_tokens_with_mask(xs)?; - for blk in self.blocks.iter() { - xs = blk.forward(&xs)? - } - let xs = self.norm.forward(&xs)?; - let xs_norm_clstoken = xs.i((.., 0))?; - let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; - let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; - self.head.forward(&xs) - } -} - -pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { - DinoVisionTransformer::new(vb, 12, 384, 6) -} #[derive(Parser)] struct Args { #[arg(long)] @@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> { let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let model = vit_small(vb)?; + let model = dinov2::vit_small(vb)?; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; let prs = candle_nn::ops::softmax(&logits, D::Minus1)? diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs index cbe2c90a..1e45e301 100644 --- a/candle-examples/examples/efficientnet/main.rs +++ b/candle-examples/examples/efficientnet/main.rs @@ -8,340 +8,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig}; use clap::{Parser, ValueEnum}; -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn as nn; -use nn::{Module, VarBuilder}; - -// Based on the Python version from torchvision. -// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 -#[derive(Debug, Clone, Copy)] -pub struct MBConvConfig { - expand_ratio: f64, - kernel: usize, - stride: usize, - input_channels: usize, - out_channels: usize, - num_layers: usize, -} - -fn make_divisible(v: f64, divisor: usize) -> usize { - let min_value = divisor; - let new_v = usize::max( - min_value, - (v + divisor as f64 * 0.5) as usize / divisor * divisor, - ); - if (new_v as f64) < 0.9 * v { - new_v + divisor - } else { - new_v - } -} - -fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { - let bneck_conf = |e, k, s, i, o, n| { - let input_channels = make_divisible(i as f64 * width_mult, 8); - let out_channels = make_divisible(o as f64 * width_mult, 8); - let num_layers = (n as f64 * depth_mult).ceil() as usize; - MBConvConfig { - expand_ratio: e, - kernel: k, - stride: s, - input_channels, - out_channels, - num_layers, - } - }; - vec![ - bneck_conf(1., 3, 1, 32, 16, 1), - bneck_conf(6., 3, 2, 16, 24, 2), - bneck_conf(6., 5, 2, 24, 40, 2), - bneck_conf(6., 3, 2, 40, 80, 3), - bneck_conf(6., 5, 1, 80, 112, 3), - bneck_conf(6., 5, 2, 112, 192, 4), - bneck_conf(6., 3, 1, 192, 320, 1), - ] -} - -impl MBConvConfig { - fn b0() -> Vec<Self> { - bneck_confs(1.0, 1.0) - } - fn b1() -> Vec<Self> { - bneck_confs(1.0, 1.1) - } - fn b2() -> Vec<Self> { - bneck_confs(1.1, 1.2) - } - fn b3() -> Vec<Self> { - bneck_confs(1.2, 1.4) - } - fn b4() -> Vec<Self> { - bneck_confs(1.4, 1.8) - } - fn b5() -> Vec<Self> { - bneck_confs(1.6, 2.2) - } - fn b6() -> Vec<Self> { - bneck_confs(1.8, 2.6) - } - fn b7() -> Vec<Self> { - bneck_confs(2.0, 3.1) - } -} - -/// Conv2D with same padding. -#[derive(Debug)] -struct Conv2DSame { - conv2d: nn::Conv2d, - s: usize, - k: usize, -} - -impl Conv2DSame { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - bias: bool, - ) -> Result<Self> { - let conv_config = nn::Conv2dConfig { - stride, - groups, - ..Default::default() - }; - let conv2d = if bias { - nn::conv2d(i, o, k, conv_config, vb)? - } else { - nn::conv2d_no_bias(i, o, k, conv_config, vb)? - }; - Ok(Self { - conv2d, - s: stride, - k, - }) - } -} - -impl Module for Conv2DSame { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let s = self.s; - let k = self.k; - let (_, _, ih, iw) = xs.dims4()?; - let oh = (ih + s - 1) / s; - let ow = (iw + s - 1) / s; - let pad_h = usize::max((oh - 1) * s + k - ih, 0); - let pad_w = usize::max((ow - 1) * s + k - iw, 0); - if pad_h > 0 || pad_w > 0 { - let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; - let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; - self.conv2d.forward(&xs) - } else { - self.conv2d.forward(xs) - } - } -} - -#[derive(Debug)] -struct ConvNormActivation { - conv2d: Conv2DSame, - bn2d: nn::BatchNorm, - activation: bool, -} - -impl ConvNormActivation { - fn new( - vb: VarBuilder, - i: usize, - o: usize, - k: usize, - stride: usize, - groups: usize, - ) -> Result<Self> { - let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; - let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; - Ok(Self { - conv2d, - bn2d, - activation: true, - }) - } - - fn no_activation(self) -> Self { - Self { - activation: false, - ..self - } - } -} - -impl Module for ConvNormActivation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.conv2d.forward(xs)?; - let xs = self.bn2d.forward(&xs)?; - if self.activation { - swish(&xs) - } else { - Ok(xs) - } - } -} - -#[derive(Debug)] -struct SqueezeExcitation { - fc1: Conv2DSame, - fc2: Conv2DSame, -} - -impl SqueezeExcitation { - fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { - let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; - let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; - Ok(Self { fc1, fc2 }) - } -} - -impl Module for SqueezeExcitation { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs; - // equivalent to adaptive_avg_pool2d([1, 1]) - let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; - let xs = self.fc1.forward(&xs)?; - let xs = swish(&xs)?; - let xs = self.fc2.forward(&xs)?; - let xs = nn::ops::sigmoid(&xs)?; - residual.broadcast_mul(&xs) - } -} - -#[derive(Debug)] -struct MBConv { - expand_cna: Option<ConvNormActivation>, - depthwise_cna: ConvNormActivation, - squeeze_excitation: SqueezeExcitation, - project_cna: ConvNormActivation, - config: MBConvConfig, -} - -impl MBConv { - fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { - let vb = vb.pp("block"); - let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); - let expand_cna = if exp != c.input_channels { - Some(ConvNormActivation::new( - vb.pp("0"), - c.input_channels, - exp, - 1, - 1, - 1, - )?) - } else { - None - }; - let start_index = if expand_cna.is_some() { 1 } else { 0 }; - let depthwise_cna = - ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; - let squeeze_channels = usize::max(1, c.input_channels / 4); - let squeeze_excitation = - SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; - let project_cna = - ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? - .no_activation(); - Ok(Self { - expand_cna, - depthwise_cna, - squeeze_excitation, - project_cna, - config: c, - }) - } -} - -impl Module for MBConv { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let use_res_connect = - self.config.stride == 1 && self.config.input_channels == self.config.out_channels; - let ys = match &self.expand_cna { - Some(expand_cna) => expand_cna.forward(xs)?, - None => xs.clone(), - }; - let ys = self.depthwise_cna.forward(&ys)?; - let ys = self.squeeze_excitation.forward(&ys)?; - let ys = self.project_cna.forward(&ys)?; - if use_res_connect { - ys + xs - } else { - Ok(ys) - } - } -} - -fn swish(s: &Tensor) -> Result<Tensor> { - s * nn::ops::sigmoid(s)? -} - -#[derive(Debug)] -struct EfficientNet { - init_cna: ConvNormActivation, - blocks: Vec<MBConv>, - final_cna: ConvNormActivation, - classifier: nn::Linear, -} - -impl EfficientNet { - fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { - let f_p = p.pp("features"); - let first_in_c = configs[0].input_channels; - let last_out_c = configs.last().unwrap().out_channels; - let final_out_c = 4 * last_out_c; - let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; - let nconfigs = configs.len(); - let mut blocks = vec![]; - for (index, cnf) in configs.into_iter().enumerate() { - let f_p = f_p.pp(index + 1); - for r_index in 0..cnf.num_layers { - let cnf = if r_index == 0 { - cnf - } else { - MBConvConfig { - input_channels: cnf.out_channels, - stride: 1, - ..cnf - } - }; - blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) - } - } - let final_cna = - ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; - let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; - Ok(Self { - init_cna, - blocks, - final_cna, - classifier, - }) - } -} - -impl Module for EfficientNet { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.init_cna.forward(xs)?; - for block in self.blocks.iter() { - xs = block.forward(&xs)? - } - let xs = self.final_cna.forward(&xs)?; - // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) - let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; - self.classifier.forward(&xs) - } -} - #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { B0, diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index a3f98d8e..c8179d33 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file}; use candle::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; -mod model; +use candle_transformers::models::quantized_llama as model; use model::ModelWeights; const DEFAULT_PROMPT: &str = "My favorite theorem is "; diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 9ce2f158..21ba0415 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -7,108 +7,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -pub mod model_image_encoder; -pub mod model_mask_decoder; -pub mod model_prompt_encoder; -pub mod model_sam; -pub mod model_tiny_vit; -pub mod model_transformer; - -use candle::{DType, Result, Tensor}; -use candle_nn::{Module, VarBuilder}; +use candle::DType; +use candle_nn::VarBuilder; +use candle_transformers::models::segment_anything::sam; use clap::Parser; -pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - let inner = if bias { - candle_nn::linear(in_dim, out_dim, vb)? - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb)? - }; - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Ok(Linear { inner, span }) -} - -#[derive(Debug)] -pub struct LayerNorm2d { - weight: Tensor, - bias: Tensor, - num_channels: usize, - eps: f64, -} - -impl LayerNorm2d { - pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let weight = vb.get(num_channels, "weight")?; - let bias = vb.get(num_channels, "bias")?; - Ok(Self { - weight, - bias, - num_channels, - eps, - }) - } -} - -impl Module for LayerNorm2d { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let u = xs.mean_keepdim(1)?; - let xs = xs.broadcast_sub(&u)?; - let s = xs.sqr()?.mean_keepdim(1)?; - let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; - xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? - .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) - } -} - -#[derive(Debug)] -pub struct MlpBlock { - lin1: Linear, - lin2: Linear, - activation: candle_nn::Activation, - span: tracing::Span, -} - -impl MlpBlock { - pub fn new( - embedding_dim: usize, - mlp_dim: usize, - activation: candle_nn::Activation, - vb: VarBuilder, - ) -> Result<Self> { - let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; - let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; - let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); - Ok(Self { - lin1, - lin2, - activation, - span, - }) - } -} - -impl Module for MlpBlock { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.lin1)? - .apply(&self.activation)? - .apply(&self.lin2) - } -} - -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Parser)] struct Args { #[arg(long)] @@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> { let (_c, h, w) = image.dims3()?; (image, h, w) } else { - let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?; + let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; (image.to_device(&device)?, h, w) }; println!("loaded image {image:?}"); @@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> { let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let sam = if args.use_tiny { - model_sam::Sam::new_tiny(vb)? // tiny vit_t + sam::Sam::new_tiny(vb)? // tiny vit_t } else { - model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b }; if args.generate_masks { diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index 20021b45..ecf75bdf 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_examples::object_detection::{non_maximum_suppression, Bbox}; +use candle_transformers::object_detection::{non_maximum_suppression, Bbox}; mod darknet; use anyhow::Result; diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 2017b5be..d48bac35 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -8,8 +8,8 @@ mod model; use model::{Multiples, YoloV8, YoloV8Pose}; use candle::{DType, Device, IndexOp, Result, Tensor}; -use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use candle_nn::{Module, VarBuilder}; +use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; use image::DynamicImage; diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index c14b2d6b..5e0b44fb 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -1,6 +1,5 @@ pub mod coco_classes; pub mod imagenet; -pub mod object_detection; use candle::{Device, Result, Tensor}; diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6b2087cb..86caf918 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -16,6 +16,7 @@ candle-nn = { path = "../candle-nn", version = "0.2.1" } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } rand = { workspace = true } +rayon = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index a8890dc8..b83e5056 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,4 +1,5 @@ pub mod generation; pub mod models; +pub mod object_detection; pub mod pipelines; pub mod utils; diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs new file mode 100644 index 00000000..0edc8494 --- /dev/null +++ b/candle-transformers/src/models/dinov2.rs @@ -0,0 +1,279 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 1000; + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result<Self> { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? * self.scale)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result<Self> { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result<Self> { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. + xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + pos_embed: Tensor, + blocks: Vec<Block>, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let num_tokens = 1; + let pos_embed = vb.get( + (1, patch_embed.num_patches + num_tokens, embed_dim), + "pos_embed", + )?; + let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + patch_embed, + cls_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(xs.clone()); + } + let class_pos_embed = self.pos_embed.i((.., ..1))?; + let patch_pos_embed = self.pos_embed.i((.., 1..))?; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + let patch_pos_embed = + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim))?; + Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _nc, w, h) = xs.dims4()?; + let xs = self.patch_embed.forward(xs)?; + let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; + &xs + &self.interpolate_pos_encoding(&xs, w, h)? + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; + let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; + self.head.forward(&xs) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { + DinoVisionTransformer::new(vb, 12, 384, 6) +} diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs new file mode 100644 index 00000000..ab51c76d --- /dev/null +++ b/candle-transformers/src/models/efficientnet.rs @@ -0,0 +1,331 @@ +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use nn::{Module, VarBuilder}; + +// Based on the Python version from torchvision. +// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 +#[derive(Debug, Clone, Copy)] +pub struct MBConvConfig { + expand_ratio: f64, + kernel: usize, + stride: usize, + input_channels: usize, + out_channels: usize, + num_layers: usize, +} + +fn make_divisible(v: f64, divisor: usize) -> usize { + let min_value = divisor; + let new_v = usize::max( + min_value, + (v + divisor as f64 * 0.5) as usize / divisor * divisor, + ); + if (new_v as f64) < 0.9 * v { + new_v + divisor + } else { + new_v + } +} + +fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { + let bneck_conf = |e, k, s, i, o, n| { + let input_channels = make_divisible(i as f64 * width_mult, 8); + let out_channels = make_divisible(o as f64 * width_mult, 8); + let num_layers = (n as f64 * depth_mult).ceil() as usize; + MBConvConfig { + expand_ratio: e, + kernel: k, + stride: s, + input_channels, + out_channels, + num_layers, + } + }; + vec![ + bneck_conf(1., 3, 1, 32, 16, 1), + bneck_conf(6., 3, 2, 16, 24, 2), + bneck_conf(6., 5, 2, 24, 40, 2), + bneck_conf(6., 3, 2, 40, 80, 3), + bneck_conf(6., 5, 1, 80, 112, 3), + bneck_conf(6., 5, 2, 112, 192, 4), + bneck_conf(6., 3, 1, 192, 320, 1), + ] +} + +impl MBConvConfig { + pub fn b0() -> Vec<Self> { + bneck_confs(1.0, 1.0) + } + pub fn b1() -> Vec<Self> { + bneck_confs(1.0, 1.1) + } + pub fn b2() -> Vec<Self> { + bneck_confs(1.1, 1.2) + } + pub fn b3() -> Vec<Self> { + bneck_confs(1.2, 1.4) + } + pub fn b4() -> Vec<Self> { + bneck_confs(1.4, 1.8) + } + pub fn b5() -> Vec<Self> { + bneck_confs(1.6, 2.2) + } + pub fn b6() -> Vec<Self> { + bneck_confs(1.8, 2.6) + } + pub fn b7() -> Vec<Self> { + bneck_confs(2.0, 3.1) + } +} + +/// Conv2D with same padding. +#[derive(Debug)] +struct Conv2DSame { + conv2d: nn::Conv2d, + s: usize, + k: usize, +} + +impl Conv2DSame { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + bias: bool, + ) -> Result<Self> { + let conv_config = nn::Conv2dConfig { + stride, + groups, + ..Default::default() + }; + let conv2d = if bias { + nn::conv2d(i, o, k, conv_config, vb)? + } else { + nn::conv2d_no_bias(i, o, k, conv_config, vb)? + }; + Ok(Self { + conv2d, + s: stride, + k, + }) + } +} + +impl Module for Conv2DSame { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let s = self.s; + let k = self.k; + let (_, _, ih, iw) = xs.dims4()?; + let oh = (ih + s - 1) / s; + let ow = (iw + s - 1) / s; + let pad_h = usize::max((oh - 1) * s + k - ih, 0); + let pad_w = usize::max((ow - 1) * s + k - iw, 0); + if pad_h > 0 || pad_w > 0 { + let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; + let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; + self.conv2d.forward(&xs) + } else { + self.conv2d.forward(xs) + } + } +} + +#[derive(Debug)] +struct ConvNormActivation { + conv2d: Conv2DSame, + bn2d: nn::BatchNorm, + activation: bool, +} + +impl ConvNormActivation { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; + let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; + Ok(Self { + conv2d, + bn2d, + activation: true, + }) + } + + fn no_activation(self) -> Self { + Self { + activation: false, + ..self + } + } +} + +impl Module for ConvNormActivation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.conv2d.forward(xs)?; + let xs = self.bn2d.forward(&xs)?; + if self.activation { + swish(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +struct SqueezeExcitation { + fc1: Conv2DSame, + fc2: Conv2DSame, +} + +impl SqueezeExcitation { + fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { + let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; + let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for SqueezeExcitation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + // equivalent to adaptive_avg_pool2d([1, 1]) + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = self.fc1.forward(&xs)?; + let xs = swish(&xs)?; + let xs = self.fc2.forward(&xs)?; + let xs = nn::ops::sigmoid(&xs)?; + residual.broadcast_mul(&xs) + } +} + +#[derive(Debug)] +struct MBConv { + expand_cna: Option<ConvNormActivation>, + depthwise_cna: ConvNormActivation, + squeeze_excitation: SqueezeExcitation, + project_cna: ConvNormActivation, + config: MBConvConfig, +} + +impl MBConv { + fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { + let vb = vb.pp("block"); + let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); + let expand_cna = if exp != c.input_channels { + Some(ConvNormActivation::new( + vb.pp("0"), + c.input_channels, + exp, + 1, + 1, + 1, + )?) + } else { + None + }; + let start_index = if expand_cna.is_some() { 1 } else { 0 }; + let depthwise_cna = + ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; + let squeeze_channels = usize::max(1, c.input_channels / 4); + let squeeze_excitation = + SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; + let project_cna = + ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? + .no_activation(); + Ok(Self { + expand_cna, + depthwise_cna, + squeeze_excitation, + project_cna, + config: c, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let use_res_connect = + self.config.stride == 1 && self.config.input_channels == self.config.out_channels; + let ys = match &self.expand_cna { + Some(expand_cna) => expand_cna.forward(xs)?, + None => xs.clone(), + }; + let ys = self.depthwise_cna.forward(&ys)?; + let ys = self.squeeze_excitation.forward(&ys)?; + let ys = self.project_cna.forward(&ys)?; + if use_res_connect { + ys + xs + } else { + Ok(ys) + } + } +} + +fn swish(s: &Tensor) -> Result<Tensor> { + s * nn::ops::sigmoid(s)? +} + +#[derive(Debug)] +pub struct EfficientNet { + init_cna: ConvNormActivation, + blocks: Vec<MBConv>, + final_cna: ConvNormActivation, + classifier: nn::Linear, +} + +impl EfficientNet { + pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { + let f_p = p.pp("features"); + let first_in_c = configs[0].input_channels; + let last_out_c = configs.last().unwrap().out_channels; + let final_out_c = 4 * last_out_c; + let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; + let nconfigs = configs.len(); + let mut blocks = vec![]; + for (index, cnf) in configs.into_iter().enumerate() { + let f_p = f_p.pp(index + 1); + for r_index in 0..cnf.num_layers { + let cnf = if r_index == 0 { + cnf + } else { + MBConvConfig { + input_channels: cnf.out_channels, + stride: 1, + ..cnf + } + }; + blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) + } + } + let final_cna = + ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; + let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; + Ok(Self { + init_cna, + blocks, + final_cna, + classifier, + }) + } +} + +impl Module for EfficientNet { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.init_cna.forward(xs)?; + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + let xs = self.final_cna.forward(&xs)?; + // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) + let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; + self.classifier.forward(&xs) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 1b3dcf25..76e13b2a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1,5 +1,9 @@ pub mod bert; pub mod bigcode; +pub mod dinov2; +pub mod efficientnet; pub mod falcon; pub mod llama; +pub mod quantized_llama; +pub mod segment_anything; pub mod whisper; diff --git a/candle-examples/examples/quantized/model.rs b/candle-transformers/src/models/quantized_llama.rs index da0bd0b0..da0bd0b0 100644 --- a/candle-examples/examples/quantized/model.rs +++ b/candle-transformers/src/models/quantized_llama.rs diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs index 76cd15d0..0b313830 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-transformers/src/models/segment_anything/image_encoder.rs @@ -100,8 +100,8 @@ impl candle::CustomOp3 for Add3 { #[derive(Debug)] struct Attention { - qkv: crate::Linear, - proj: crate::Linear, + qkv: super::Linear, + proj: super::Linear, num_heads: usize, scale: f64, rel_pos_hw: Option<(Tensor, Tensor)>, @@ -124,8 +124,8 @@ impl Attention { let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); - let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; - let proj = crate::linear(vb.pp("proj"), dim, dim, true)?; + let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = super::linear(vb.pp("proj"), dim, dim, true)?; let head_dim = dim / num_heads; let scale = 1. / (head_dim as f64).sqrt(); let rel_pos_hw = if use_rel_pos { @@ -251,7 +251,7 @@ struct Block { norm1: LayerNorm, attn: Attention, norm2: LayerNorm, - mlp: crate::MlpBlock, + mlp: super::MlpBlock, window_size: usize, span: tracing::Span, } @@ -281,7 +281,7 @@ impl Block { input_size_attn, vb.pp("attn"), )?; - let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; + let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; let span = tracing::span!(tracing::Level::TRACE, "ie-block"); Ok(Self { norm1, @@ -375,9 +375,9 @@ pub struct ImageEncoderViT { patch_embed: PatchEmbed, blocks: Vec<Block>, neck_conv1: candle_nn::Conv2d, - neck_ln1: crate::LayerNorm2d, + neck_ln1: super::LayerNorm2d, neck_conv2: candle_nn::Conv2d, - neck_ln2: crate::LayerNorm2d, + neck_ln2: super::LayerNorm2d, pos_embed: Option<Tensor>, span: tracing::Span, } @@ -433,13 +433,13 @@ impl ImageEncoderViT { Default::default(), vb.pp("neck.0"), )?; - let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; + let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; let cfg = candle_nn::Conv2dConfig { padding: 1, ..Default::default() }; let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; + let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; let pos_embed = if use_abs_pos { let p = vb.get( (1, img_size / patch_size, img_size / patch_size, embed_dim), diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs index c02b44a7..2a91cd44 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -1,11 +1,11 @@ use candle::{IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; -use crate::model_transformer::TwoWayTransformer; +use super::transformer::TwoWayTransformer; #[derive(Debug)] struct MlpMaskDecoder { - layers: Vec<crate::Linear>, + layers: Vec<super::Linear>, sigmoid_output: bool, span: tracing::Span, } @@ -28,7 +28,7 @@ impl MlpMaskDecoder { } else { hidden_dim }; - let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?; + let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?; layers.push(layer) } let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); @@ -64,7 +64,7 @@ pub struct MaskDecoder { mask_tokens: candle_nn::Embedding, iou_prediction_head: MlpMaskDecoder, output_upscaling_conv1: candle_nn::ConvTranspose2d, - output_upscaling_ln: crate::LayerNorm2d, + output_upscaling_ln: super::LayerNorm2d, output_upscaling_conv2: candle_nn::ConvTranspose2d, num_mask_tokens: usize, output_hypernetworks_mlps: Vec<MlpMaskDecoder>, @@ -104,7 +104,7 @@ impl MaskDecoder { vb.pp("output_upscaling.0"), )?; let output_upscaling_ln = - crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; let output_upscaling_conv2 = candle_nn::conv_transpose2d( transformer_dim / 4, transformer_dim / 8, diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs new file mode 100644 index 00000000..c29db70a --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -0,0 +1,100 @@ +use candle::{Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +pub mod image_encoder; +pub mod mask_decoder; +pub mod prompt_encoder; +pub mod sam; +pub mod tiny_vit; +pub mod transformer; + +pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +#[derive(Debug)] +pub struct LayerNorm2d { + weight: Tensor, + bias: Tensor, + num_channels: usize, + eps: f64, +} + +impl LayerNorm2d { + pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(num_channels, "weight")?; + let bias = vb.get(num_channels, "bias")?; + Ok(Self { + weight, + bias, + num_channels, + eps, + }) + } +} + +impl Module for LayerNorm2d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let u = xs.mean_keepdim(1)?; + let xs = xs.broadcast_sub(&u)?; + let s = xs.sqr()?.mean_keepdim(1)?; + let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; + xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? + .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) + } +} + +#[derive(Debug)] +pub struct MlpBlock { + lin1: Linear, + lin2: Linear, + activation: candle_nn::Activation, + span: tracing::Span, +} + +impl MlpBlock { + pub fn new( + embedding_dim: usize, + mlp_dim: usize, + activation: candle_nn::Activation, + vb: VarBuilder, + ) -> Result<Self> { + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); + Ok(Self { + lin1, + lin2, + activation, + span, + }) + } +} + +impl Module for MlpBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.lin1)? + .apply(&self.activation)? + .apply(&self.lin2) + } +} + +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs index 7bbe8419..9d0074b1 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -56,9 +56,9 @@ pub struct PromptEncoder { point_embeddings: Vec<candle_nn::Embedding>, not_a_point_embed: candle_nn::Embedding, mask_downscaling_conv1: candle_nn::Conv2d, - mask_downscaling_ln1: crate::LayerNorm2d, + mask_downscaling_ln1: super::LayerNorm2d, mask_downscaling_conv2: candle_nn::Conv2d, - mask_downscaling_ln2: crate::LayerNorm2d, + mask_downscaling_ln2: super::LayerNorm2d, mask_downscaling_conv3: candle_nn::Conv2d, no_mask_embed: candle_nn::Embedding, image_embedding_size: (usize, usize), @@ -100,9 +100,9 @@ impl PromptEncoder { vb.pp("mask_downscaling.6"), )?; let mask_downscaling_ln1 = - crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; + super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; let mask_downscaling_ln2 = - crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; + super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; let mut point_embeddings = Vec::with_capacity(num_points_embeddings); let vb_e = vb.pp("point_embeddings"); for i in 0..num_points_embeddings { diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-transformers/src/models/segment_anything/sam.rs index b1a81af6..c40473e3 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -1,10 +1,10 @@ use candle::{DType, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; -use crate::model_image_encoder::ImageEncoderViT; -use crate::model_mask_decoder::MaskDecoder; -use crate::model_prompt_encoder::PromptEncoder; -use crate::model_tiny_vit::{tiny_vit_5m, TinyViT}; +use super::image_encoder::ImageEncoderViT; +use super::mask_decoder::MaskDecoder; +use super::prompt_encoder::PromptEncoder; +use super::tiny_vit::{tiny_vit_5m, TinyViT}; const PROMPT_EMBED_DIM: usize = 256; pub const IMAGE_SIZE: usize = 1024; @@ -186,7 +186,7 @@ impl Sam { img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)], - ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { // Crop the image and calculate embeddings. let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; let img = self.preprocess(&img)?.unsqueeze(0)?; @@ -259,7 +259,7 @@ impl Sam { let min_max_x = min_max_indexes(&low_res_mask_per_x); let min_max_y = min_max_indexes(&low_res_mask_per_y); if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { - let bbox = candle_examples::object_detection::Bbox { + let bbox = crate::object_detection::Bbox { xmin: x0 as f32, ymin: y0 as f32, xmax: x1 as f32, @@ -277,7 +277,7 @@ impl Sam { let mut bboxes = vec![bboxes]; // Remove duplicates within this crop. - candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); + crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); // TODO: Return to the original image frame. Ok(bboxes.remove(0)) @@ -290,7 +290,7 @@ impl Sam { crop_n_layer: usize, crop_overlap_ratio: f64, crop_n_points_downscale_factor: usize, - ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { let (_c, h, w) = img.dims3()?; let point_grids = build_all_layer_point_grids( points_per_side, diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs index ff076773..cd2936ab 100644 --- a/candle-examples/examples/segment-anything/model_tiny_vit.rs +++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs @@ -215,16 +215,16 @@ impl Module for ConvLayer { #[derive(Debug)] struct Mlp { norm: candle_nn::LayerNorm, - fc1: crate::Linear, - fc2: crate::Linear, + fc1: super::Linear, + fc2: super::Linear, span: tracing::Span, } impl Mlp { fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> { let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?; - let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?; - let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?; + let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?; + let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?; let span = tracing::span!(tracing::Level::TRACE, "mlp"); Ok(Self { norm, @@ -248,8 +248,8 @@ impl Module for Mlp { #[derive(Debug)] struct Attention { norm: candle_nn::LayerNorm, - qkv: crate::Linear, - proj: crate::Linear, + qkv: super::Linear, + proj: super::Linear, ab: Tensor, key_dim: usize, num_heads: usize, @@ -275,8 +275,8 @@ impl Attention { let nh_kd = key_dim * num_heads; let h = dh + nh_kd * 2; let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?; - let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?; - let proj = crate::linear(vb.pp("proj"), dh, dim, true)?; + let qkv = super::linear(vb.pp("qkv"), dim, h, true)?; + let proj = super::linear(vb.pp("proj"), dh, dim, true)?; let points = (0..resolution.0) .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64))) @@ -526,9 +526,9 @@ pub struct TinyViT { // norm_head: candle_nn::LayerNorm, // head: candle_nn::Linear, neck_conv1: candle_nn::Conv2d, - neck_ln1: crate::LayerNorm2d, + neck_ln1: super::LayerNorm2d, neck_conv2: candle_nn::Conv2d, - neck_ln2: crate::LayerNorm2d, + neck_ln2: super::LayerNorm2d, span: tracing::Span, span_neck: tracing::Span, } @@ -578,13 +578,13 @@ impl TinyViT { // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; let neck_conv1 = candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; - let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; + let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; let cfg = candle_nn::Conv2dConfig { padding: 1, ..Default::default() }; let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; + let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; let span = tracing::span!(tracing::Level::TRACE, "tiny-vit"); let span_neck = tracing::span!(tracing::Level::TRACE, "neck"); diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs index e12aac08..80efb38c 100644 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ b/candle-transformers/src/models/segment_anything/transformer.rs @@ -68,7 +68,7 @@ struct TwoWayAttentionBlock { norm1: LayerNorm, cross_attn_token_to_image: Attention, norm2: LayerNorm, - mlp: crate::MlpBlock, + mlp: super::MlpBlock, norm3: LayerNorm, norm4: LayerNorm, cross_attn_image_to_token: Attention, @@ -100,7 +100,7 @@ impl TwoWayAttentionBlock { 2, vb.pp("cross_attn_image_to_token"), )?; - let mlp = crate::MlpBlock::new( + let mlp = super::MlpBlock::new( embedding_dim, mlp_dim, candle_nn::Activation::Relu, diff --git a/candle-examples/src/object_detection.rs b/candle-transformers/src/object_detection.rs index ce579316..ce579316 100644 --- a/candle-examples/src/object_detection.rs +++ b/candle-transformers/src/object_detection.rs |