summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-10 10:20:18 +0100
committerGitHub <noreply@github.com>2023-09-10 10:20:18 +0100
commit35f72514f59b3fa4bd321e3e88a75f5b43cf060f (patch)
tree37dd25098bcf16293744758268a0486337d18431
parentd3f05eae8c4f2df186b46e433be101ac39fceca5 (diff)
downloadcandle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.tar.gz
candle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.tar.bz2
candle-35f72514f59b3fa4bd321e3e88a75f5b43cf060f.zip
Move more models to candle-transformers (#796)
* Move dinov2. * Move efficientnet. * Move the quantized llama model. * Move segment-anything.
-rw-r--r--candle-examples/examples/dinov2/main.rs283
-rw-r--r--candle-examples/examples/efficientnet/main.rs335
-rw-r--r--candle-examples/examples/quantized/main.rs2
-rw-r--r--candle-examples/examples/segment-anything/main.rs109
-rw-r--r--candle-examples/examples/yolo-v3/main.rs2
-rw-r--r--candle-examples/examples/yolo-v8/main.rs2
-rw-r--r--candle-examples/src/lib.rs1
-rw-r--r--candle-transformers/Cargo.toml1
-rw-r--r--candle-transformers/src/lib.rs1
-rw-r--r--candle-transformers/src/models/dinov2.rs279
-rw-r--r--candle-transformers/src/models/efficientnet.rs331
-rw-r--r--candle-transformers/src/models/mod.rs4
-rw-r--r--candle-transformers/src/models/quantized_llama.rs (renamed from candle-examples/examples/quantized/model.rs)0
-rw-r--r--candle-transformers/src/models/segment_anything/image_encoder.rs (renamed from candle-examples/examples/segment-anything/model_image_encoder.rs)20
-rw-r--r--candle-transformers/src/models/segment_anything/mask_decoder.rs (renamed from candle-examples/examples/segment-anything/model_mask_decoder.rs)10
-rw-r--r--candle-transformers/src/models/segment_anything/mod.rs100
-rw-r--r--candle-transformers/src/models/segment_anything/prompt_encoder.rs (renamed from candle-examples/examples/segment-anything/model_prompt_encoder.rs)8
-rw-r--r--candle-transformers/src/models/segment_anything/sam.rs (renamed from candle-examples/examples/segment-anything/model_sam.rs)16
-rw-r--r--candle-transformers/src/models/segment_anything/tiny_vit.rs (renamed from candle-examples/examples/segment-anything/model_tiny_vit.rs)24
-rw-r--r--candle-transformers/src/models/segment_anything/transformer.rs (renamed from candle-examples/examples/segment-anything/model_transformer.rs)4
-rw-r--r--candle-transformers/src/object_detection.rs (renamed from candle-examples/src/object_detection.rs)0
21 files changed, 773 insertions, 759 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index e80c81e2..d3adb37c 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2;
-const IMG_SIZE: usize = 518;
-const PATCH_SIZE: usize = 14;
-const NUM_CLASSES: usize = 1000;
-
-fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
-}
-
-#[derive(Debug)]
-struct Attention {
- qkv: Linear,
- proj: Linear,
- num_heads: usize,
- scale: f64,
-}
-
-impl Attention {
- fn new(
- vb: VarBuilder,
- dim: usize,
- num_heads: usize,
- qkv_bias: bool,
- proj_bias: bool,
- ) -> Result<Self> {
- let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
- let scale = 1. / ((dim / num_heads) as f64).sqrt();
- Ok(Self {
- qkv,
- proj,
- num_heads,
- scale,
- })
- }
-}
-
-impl Module for Attention {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (b, n, c) = xs.dims3()?;
- let qkv = self
- .qkv
- .forward(xs)?
- .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)? // 02134
- .transpose(0, 1)? // 20134
- .transpose(2, 3)?; // 20314
- let q = (qkv.i(0)? * self.scale)?;
- let k = qkv.i(1)?;
- let v = qkv.i(2)?;
- let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
- let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
- self.proj.forward(&attn)
- }
-}
-
-#[derive(Debug)]
-struct LayerScale {
- gamma: Tensor,
-}
-
-impl LayerScale {
- fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
- let gamma = vb.get(dim, "gamma")?;
- Ok(Self { gamma })
- }
-}
-
-impl Module for LayerScale {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.broadcast_mul(&self.gamma)
- }
-}
-
-#[derive(Debug)]
-struct Mlp {
- fc1: Linear,
- fc2: Linear,
-}
-
-impl Mlp {
- fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
- let out_features = in_features;
- let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
- let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for Mlp {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.fc1.forward(xs)?.gelu()?;
- self.fc2.forward(&xs)
- }
-}
-
-#[derive(Debug)]
-struct Block {
- norm1: LayerNorm,
- attn: Attention,
- ls1: LayerScale,
- norm2: LayerNorm,
- mlp: Mlp,
- ls2: LayerScale,
-}
-
-impl Block {
- fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
- let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
- let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
- let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
- let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
- let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
- let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
- Ok(Self {
- norm1,
- attn,
- ls1,
- norm2,
- mlp,
- ls2,
- })
- }
-}
-
-impl Module for Block {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- let xs = self
- .ls1
- .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
- let xs = (xs + residual)?;
- let residual = &xs;
- let xs = self
- .ls2
- .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
- xs + residual
- }
-}
-
-#[derive(Debug)]
-struct PatchEmbed {
- proj: candle_nn::Conv2d,
- patch_size: (usize, usize),
- num_patches: usize,
-}
-
-impl PatchEmbed {
- fn new(
- vb: VarBuilder,
- img_size: usize,
- patch_size: usize,
- in_chans: usize,
- embed_dim: usize,
- ) -> Result<Self> {
- let config = candle_nn::Conv2dConfig {
- stride: patch_size,
- ..Default::default()
- };
- let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
- let num_patches = (img_size / patch_size) * (img_size / patch_size);
- Ok(Self {
- proj,
- patch_size: (patch_size, patch_size),
- num_patches,
- })
- }
-}
-
-impl Module for PatchEmbed {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _c, h, w) = xs.dims4()?;
- let (patch_h, patch_w) = self.patch_size;
- if (h % patch_h) != 0 {
- candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
- }
- if (w % patch_w) != 0 {
- candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
- }
- let xs = self.proj.forward(xs)?;
- let (b, c, h, w) = xs.dims4()?;
- // flatten embeddings.
- xs.reshape((b, c, h * w))?.transpose(1, 2)
- }
-}
-
-#[derive(Debug)]
-pub struct DinoVisionTransformer {
- patch_embed: PatchEmbed,
- cls_token: Tensor,
- pos_embed: Tensor,
- blocks: Vec<Block>,
- norm: LayerNorm,
- head: Linear,
-}
-
-impl DinoVisionTransformer {
- pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
- let patch_embed =
- PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
- let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
- let num_tokens = 1;
- let pos_embed = vb.get(
- (1, patch_embed.num_patches + num_tokens, embed_dim),
- "pos_embed",
- )?;
- let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
- let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
- let vb_b = vb.pp("blocks");
- let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- patch_embed,
- cls_token,
- pos_embed,
- blocks,
- norm,
- head,
- })
- }
-
- fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
- let npatch = xs.dim(1)? - 1;
- let n = self.pos_embed.dim(1)? - 1;
- let sqrt_n = (n as f64).sqrt();
- if npatch == n && w == h {
- return Ok(xs.clone());
- }
- let class_pos_embed = self.pos_embed.i((.., ..1))?;
- let patch_pos_embed = self.pos_embed.i((.., 1..))?;
- let dim = xs.dim(D::Minus1)?;
- let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
- let patch_pos_embed = patch_pos_embed
- .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
- .transpose(2, 3)?
- .transpose(1, 2)?;
- // This uses bicubic interpolation in the original implementation.
- let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
- let el_count = patch_pos_embed.shape().elem_count();
- let patch_pos_embed =
- patch_pos_embed
- .transpose(1, 2)?
- .transpose(2, 3)?
- .reshape((1, el_count / dim, dim))?;
- Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
- }
-
- fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _nc, w, h) = xs.dims4()?;
- let xs = self.patch_embed.forward(xs)?;
- let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
- &xs + &self.interpolate_pos_encoding(&xs, w, h)?
- }
-}
-
-impl Module for DinoVisionTransformer {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.prepare_tokens_with_mask(xs)?;
- for blk in self.blocks.iter() {
- xs = blk.forward(&xs)?
- }
- let xs = self.norm.forward(&xs)?;
- let xs_norm_clstoken = xs.i((.., 0))?;
- let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
- let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
- self.head.forward(&xs)
- }
-}
-
-pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
- DinoVisionTransformer::new(vb, 12, 384, 6)
-}
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
- let model = vit_small(vb)?;
+ let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index cbe2c90a..1e45e301 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn as nn;
-use nn::{Module, VarBuilder};
-
-// Based on the Python version from torchvision.
-// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
-#[derive(Debug, Clone, Copy)]
-pub struct MBConvConfig {
- expand_ratio: f64,
- kernel: usize,
- stride: usize,
- input_channels: usize,
- out_channels: usize,
- num_layers: usize,
-}
-
-fn make_divisible(v: f64, divisor: usize) -> usize {
- let min_value = divisor;
- let new_v = usize::max(
- min_value,
- (v + divisor as f64 * 0.5) as usize / divisor * divisor,
- );
- if (new_v as f64) < 0.9 * v {
- new_v + divisor
- } else {
- new_v
- }
-}
-
-fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
- let bneck_conf = |e, k, s, i, o, n| {
- let input_channels = make_divisible(i as f64 * width_mult, 8);
- let out_channels = make_divisible(o as f64 * width_mult, 8);
- let num_layers = (n as f64 * depth_mult).ceil() as usize;
- MBConvConfig {
- expand_ratio: e,
- kernel: k,
- stride: s,
- input_channels,
- out_channels,
- num_layers,
- }
- };
- vec![
- bneck_conf(1., 3, 1, 32, 16, 1),
- bneck_conf(6., 3, 2, 16, 24, 2),
- bneck_conf(6., 5, 2, 24, 40, 2),
- bneck_conf(6., 3, 2, 40, 80, 3),
- bneck_conf(6., 5, 1, 80, 112, 3),
- bneck_conf(6., 5, 2, 112, 192, 4),
- bneck_conf(6., 3, 1, 192, 320, 1),
- ]
-}
-
-impl MBConvConfig {
- fn b0() -> Vec<Self> {
- bneck_confs(1.0, 1.0)
- }
- fn b1() -> Vec<Self> {
- bneck_confs(1.0, 1.1)
- }
- fn b2() -> Vec<Self> {
- bneck_confs(1.1, 1.2)
- }
- fn b3() -> Vec<Self> {
- bneck_confs(1.2, 1.4)
- }
- fn b4() -> Vec<Self> {
- bneck_confs(1.4, 1.8)
- }
- fn b5() -> Vec<Self> {
- bneck_confs(1.6, 2.2)
- }
- fn b6() -> Vec<Self> {
- bneck_confs(1.8, 2.6)
- }
- fn b7() -> Vec<Self> {
- bneck_confs(2.0, 3.1)
- }
-}
-
-/// Conv2D with same padding.
-#[derive(Debug)]
-struct Conv2DSame {
- conv2d: nn::Conv2d,
- s: usize,
- k: usize,
-}
-
-impl Conv2DSame {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- bias: bool,
- ) -> Result<Self> {
- let conv_config = nn::Conv2dConfig {
- stride,
- groups,
- ..Default::default()
- };
- let conv2d = if bias {
- nn::conv2d(i, o, k, conv_config, vb)?
- } else {
- nn::conv2d_no_bias(i, o, k, conv_config, vb)?
- };
- Ok(Self {
- conv2d,
- s: stride,
- k,
- })
- }
-}
-
-impl Module for Conv2DSame {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let s = self.s;
- let k = self.k;
- let (_, _, ih, iw) = xs.dims4()?;
- let oh = (ih + s - 1) / s;
- let ow = (iw + s - 1) / s;
- let pad_h = usize::max((oh - 1) * s + k - ih, 0);
- let pad_w = usize::max((ow - 1) * s + k - iw, 0);
- if pad_h > 0 || pad_w > 0 {
- let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
- let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
- self.conv2d.forward(&xs)
- } else {
- self.conv2d.forward(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct ConvNormActivation {
- conv2d: Conv2DSame,
- bn2d: nn::BatchNorm,
- activation: bool,
-}
-
-impl ConvNormActivation {
- fn new(
- vb: VarBuilder,
- i: usize,
- o: usize,
- k: usize,
- stride: usize,
- groups: usize,
- ) -> Result<Self> {
- let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
- let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
- Ok(Self {
- conv2d,
- bn2d,
- activation: true,
- })
- }
-
- fn no_activation(self) -> Self {
- Self {
- activation: false,
- ..self
- }
- }
-}
-
-impl Module for ConvNormActivation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let xs = self.conv2d.forward(xs)?;
- let xs = self.bn2d.forward(&xs)?;
- if self.activation {
- swish(&xs)
- } else {
- Ok(xs)
- }
- }
-}
-
-#[derive(Debug)]
-struct SqueezeExcitation {
- fc1: Conv2DSame,
- fc2: Conv2DSame,
-}
-
-impl SqueezeExcitation {
- fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
- let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
- let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
- Ok(Self { fc1, fc2 })
- }
-}
-
-impl Module for SqueezeExcitation {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let residual = xs;
- // equivalent to adaptive_avg_pool2d([1, 1])
- let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = swish(&xs)?;
- let xs = self.fc2.forward(&xs)?;
- let xs = nn::ops::sigmoid(&xs)?;
- residual.broadcast_mul(&xs)
- }
-}
-
-#[derive(Debug)]
-struct MBConv {
- expand_cna: Option<ConvNormActivation>,
- depthwise_cna: ConvNormActivation,
- squeeze_excitation: SqueezeExcitation,
- project_cna: ConvNormActivation,
- config: MBConvConfig,
-}
-
-impl MBConv {
- fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
- let vb = vb.pp("block");
- let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
- let expand_cna = if exp != c.input_channels {
- Some(ConvNormActivation::new(
- vb.pp("0"),
- c.input_channels,
- exp,
- 1,
- 1,
- 1,
- )?)
- } else {
- None
- };
- let start_index = if expand_cna.is_some() { 1 } else { 0 };
- let depthwise_cna =
- ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
- let squeeze_channels = usize::max(1, c.input_channels / 4);
- let squeeze_excitation =
- SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
- let project_cna =
- ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
- .no_activation();
- Ok(Self {
- expand_cna,
- depthwise_cna,
- squeeze_excitation,
- project_cna,
- config: c,
- })
- }
-}
-
-impl Module for MBConv {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let use_res_connect =
- self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
- let ys = match &self.expand_cna {
- Some(expand_cna) => expand_cna.forward(xs)?,
- None => xs.clone(),
- };
- let ys = self.depthwise_cna.forward(&ys)?;
- let ys = self.squeeze_excitation.forward(&ys)?;
- let ys = self.project_cna.forward(&ys)?;
- if use_res_connect {
- ys + xs
- } else {
- Ok(ys)
- }
- }
-}
-
-fn swish(s: &Tensor) -> Result<Tensor> {
- s * nn::ops::sigmoid(s)?
-}
-
-#[derive(Debug)]
-struct EfficientNet {
- init_cna: ConvNormActivation,
- blocks: Vec<MBConv>,
- final_cna: ConvNormActivation,
- classifier: nn::Linear,
-}
-
-impl EfficientNet {
- fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
- let f_p = p.pp("features");
- let first_in_c = configs[0].input_channels;
- let last_out_c = configs.last().unwrap().out_channels;
- let final_out_c = 4 * last_out_c;
- let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
- let nconfigs = configs.len();
- let mut blocks = vec![];
- for (index, cnf) in configs.into_iter().enumerate() {
- let f_p = f_p.pp(index + 1);
- for r_index in 0..cnf.num_layers {
- let cnf = if r_index == 0 {
- cnf
- } else {
- MBConvConfig {
- input_channels: cnf.out_channels,
- stride: 1,
- ..cnf
- }
- };
- blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
- }
- }
- let final_cna =
- ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
- let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
- Ok(Self {
- init_cna,
- blocks,
- final_cna,
- classifier,
- })
- }
-}
-
-impl Module for EfficientNet {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let mut xs = self.init_cna.forward(xs)?;
- for block in self.blocks.iter() {
- xs = block.forward(&xs)?
- }
- let xs = self.final_cna.forward(&xs)?;
- // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
- let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
- self.classifier.forward(&xs)
- }
-}
-
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a3f98d8e..c8179d33 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
-mod model;
+use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 9ce2f158..21ba0415 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-pub mod model_image_encoder;
-pub mod model_mask_decoder;
-pub mod model_prompt_encoder;
-pub mod model_sam;
-pub mod model_tiny_vit;
-pub mod model_transformer;
-
-use candle::{DType, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segment_anything::sam;
use clap::Parser;
-pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- let inner = if bias {
- candle_nn::linear(in_dim, out_dim, vb)?
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)?
- };
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Ok(Linear { inner, span })
-}
-
-#[derive(Debug)]
-pub struct LayerNorm2d {
- weight: Tensor,
- bias: Tensor,
- num_channels: usize,
- eps: f64,
-}
-
-impl LayerNorm2d {
- pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let weight = vb.get(num_channels, "weight")?;
- let bias = vb.get(num_channels, "bias")?;
- Ok(Self {
- weight,
- bias,
- num_channels,
- eps,
- })
- }
-}
-
-impl Module for LayerNorm2d {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let u = xs.mean_keepdim(1)?;
- let xs = xs.broadcast_sub(&u)?;
- let s = xs.sqr()?.mean_keepdim(1)?;
- let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
- xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
- .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
- }
-}
-
-#[derive(Debug)]
-pub struct MlpBlock {
- lin1: Linear,
- lin2: Linear,
- activation: candle_nn::Activation,
- span: tracing::Span,
-}
-
-impl MlpBlock {
- pub fn new(
- embedding_dim: usize,
- mlp_dim: usize,
- activation: candle_nn::Activation,
- vb: VarBuilder,
- ) -> Result<Self> {
- let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
- let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
- let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
- Ok(Self {
- lin1,
- lin2,
- activation,
- span,
- })
- }
-}
-
-impl Module for MlpBlock {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.lin1)?
- .apply(&self.activation)?
- .apply(&self.lin2)
- }
-}
-
-#[derive(Debug)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
- let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
+ let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
- model_sam::Sam::new_tiny(vb)? // tiny vit_t
+ sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
- model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 20021b45..ecf75bdf 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_examples::object_detection::{non_maximum_suppression, Bbox};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index 2017b5be..d48bac35 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
-use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
+use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index c14b2d6b..5e0b44fb 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1,6 +1,5 @@
pub mod coco_classes;
pub mod imagenet;
-pub mod object_detection;
use candle::{Device, Result, Tensor};
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 6b2087cb..86caf918 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -16,6 +16,7 @@ candle-nn = { path = "../candle-nn", version = "0.2.1" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
+rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index a8890dc8..b83e5056 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -1,4 +1,5 @@
pub mod generation;
pub mod models;
+pub mod object_detection;
pub mod pipelines;
pub mod utils;
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
new file mode 100644
index 00000000..0edc8494
--- /dev/null
+++ b/candle-transformers/src/models/dinov2.rs
@@ -0,0 +1,279 @@
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+const IMG_SIZE: usize = 518;
+const PATCH_SIZE: usize = 14;
+const NUM_CLASSES: usize = 1000;
+
+fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ if bias {
+ candle_nn::linear(in_dim, out_dim, vb)
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ qkv: Linear,
+ proj: Linear,
+ num_heads: usize,
+ scale: f64,
+}
+
+impl Attention {
+ fn new(
+ vb: VarBuilder,
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ proj_bias: bool,
+ ) -> Result<Self> {
+ let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
+ let scale = 1. / ((dim / num_heads) as f64).sqrt();
+ Ok(Self {
+ qkv,
+ proj,
+ num_heads,
+ scale,
+ })
+ }
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b, n, c) = xs.dims3()?;
+ let qkv = self
+ .qkv
+ .forward(xs)?
+ .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
+ .transpose(1, 2)? // 02134
+ .transpose(0, 1)? // 20134
+ .transpose(2, 3)?; // 20314
+ let q = (qkv.i(0)? * self.scale)?;
+ let k = qkv.i(1)?;
+ let v = qkv.i(2)?;
+ let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
+ let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
+ self.proj.forward(&attn)
+ }
+}
+
+#[derive(Debug)]
+struct LayerScale {
+ gamma: Tensor,
+}
+
+impl LayerScale {
+ fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
+ let gamma = vb.get(dim, "gamma")?;
+ Ok(Self { gamma })
+ }
+}
+
+impl Module for LayerScale {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.broadcast_mul(&self.gamma)
+ }
+}
+
+#[derive(Debug)]
+struct Mlp {
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl Mlp {
+ fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
+ let out_features = in_features;
+ let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
+ let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?.gelu()?;
+ self.fc2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct Block {
+ norm1: LayerNorm,
+ attn: Attention,
+ ls1: LayerScale,
+ norm2: LayerNorm,
+ mlp: Mlp,
+ ls2: LayerScale,
+}
+
+impl Block {
+ fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
+ let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+ let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
+ let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
+ let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+ let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
+ let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
+ Ok(Self {
+ norm1,
+ attn,
+ ls1,
+ norm2,
+ mlp,
+ ls2,
+ })
+ }
+}
+
+impl Module for Block {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self
+ .ls1
+ .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = self
+ .ls2
+ .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ proj: candle_nn::Conv2d,
+ patch_size: (usize, usize),
+ num_patches: usize,
+}
+
+impl PatchEmbed {
+ fn new(
+ vb: VarBuilder,
+ img_size: usize,
+ patch_size: usize,
+ in_chans: usize,
+ embed_dim: usize,
+ ) -> Result<Self> {
+ let config = candle_nn::Conv2dConfig {
+ stride: patch_size,
+ ..Default::default()
+ };
+ let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
+ let num_patches = (img_size / patch_size) * (img_size / patch_size);
+ Ok(Self {
+ proj,
+ patch_size: (patch_size, patch_size),
+ num_patches,
+ })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _c, h, w) = xs.dims4()?;
+ let (patch_h, patch_w) = self.patch_size;
+ if (h % patch_h) != 0 {
+ candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
+ }
+ if (w % patch_w) != 0 {
+ candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
+ }
+ let xs = self.proj.forward(xs)?;
+ let (b, c, h, w) = xs.dims4()?;
+ // flatten embeddings.
+ xs.reshape((b, c, h * w))?.transpose(1, 2)
+ }
+}
+
+#[derive(Debug)]
+pub struct DinoVisionTransformer {
+ patch_embed: PatchEmbed,
+ cls_token: Tensor,
+ pos_embed: Tensor,
+ blocks: Vec<Block>,
+ norm: LayerNorm,
+ head: Linear,
+}
+
+impl DinoVisionTransformer {
+ pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
+ let patch_embed =
+ PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
+ let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
+ let num_tokens = 1;
+ let pos_embed = vb.get(
+ (1, patch_embed.num_patches + num_tokens, embed_dim),
+ "pos_embed",
+ )?;
+ let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
+ let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
+ let vb_b = vb.pp("blocks");
+ let blocks = (0..depth)
+ .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ patch_embed,
+ cls_token,
+ pos_embed,
+ blocks,
+ norm,
+ head,
+ })
+ }
+
+ fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
+ let npatch = xs.dim(1)? - 1;
+ let n = self.pos_embed.dim(1)? - 1;
+ let sqrt_n = (n as f64).sqrt();
+ if npatch == n && w == h {
+ return Ok(xs.clone());
+ }
+ let class_pos_embed = self.pos_embed.i((.., ..1))?;
+ let patch_pos_embed = self.pos_embed.i((.., 1..))?;
+ let dim = xs.dim(D::Minus1)?;
+ let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
+ let patch_pos_embed = patch_pos_embed
+ .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
+ .transpose(2, 3)?
+ .transpose(1, 2)?;
+ // This uses bicubic interpolation in the original implementation.
+ let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
+ let el_count = patch_pos_embed.shape().elem_count();
+ let patch_pos_embed =
+ patch_pos_embed
+ .transpose(1, 2)?
+ .transpose(2, 3)?
+ .reshape((1, el_count / dim, dim))?;
+ Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
+ }
+
+ fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _nc, w, h) = xs.dims4()?;
+ let xs = self.patch_embed.forward(xs)?;
+ let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
+ &xs + &self.interpolate_pos_encoding(&xs, w, h)?
+ }
+}
+
+impl Module for DinoVisionTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.prepare_tokens_with_mask(xs)?;
+ for blk in self.blocks.iter() {
+ xs = blk.forward(&xs)?
+ }
+ let xs = self.norm.forward(&xs)?;
+ let xs_norm_clstoken = xs.i((.., 0))?;
+ let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
+ let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
+ self.head.forward(&xs)
+ }
+}
+
+pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
+ DinoVisionTransformer::new(vb, 12, 384, 6)
+}
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
new file mode 100644
index 00000000..ab51c76d
--- /dev/null
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -0,0 +1,331 @@
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use nn::{Module, VarBuilder};
+
+// Based on the Python version from torchvision.
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
+#[derive(Debug, Clone, Copy)]
+pub struct MBConvConfig {
+ expand_ratio: f64,
+ kernel: usize,
+ stride: usize,
+ input_channels: usize,
+ out_channels: usize,
+ num_layers: usize,
+}
+
+fn make_divisible(v: f64, divisor: usize) -> usize {
+ let min_value = divisor;
+ let new_v = usize::max(
+ min_value,
+ (v + divisor as f64 * 0.5) as usize / divisor * divisor,
+ );
+ if (new_v as f64) < 0.9 * v {
+ new_v + divisor
+ } else {
+ new_v
+ }
+}
+
+fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
+ let bneck_conf = |e, k, s, i, o, n| {
+ let input_channels = make_divisible(i as f64 * width_mult, 8);
+ let out_channels = make_divisible(o as f64 * width_mult, 8);
+ let num_layers = (n as f64 * depth_mult).ceil() as usize;
+ MBConvConfig {
+ expand_ratio: e,
+ kernel: k,
+ stride: s,
+ input_channels,
+ out_channels,
+ num_layers,
+ }
+ };
+ vec![
+ bneck_conf(1., 3, 1, 32, 16, 1),
+ bneck_conf(6., 3, 2, 16, 24, 2),
+ bneck_conf(6., 5, 2, 24, 40, 2),
+ bneck_conf(6., 3, 2, 40, 80, 3),
+ bneck_conf(6., 5, 1, 80, 112, 3),
+ bneck_conf(6., 5, 2, 112, 192, 4),
+ bneck_conf(6., 3, 1, 192, 320, 1),
+ ]
+}
+
+impl MBConvConfig {
+ pub fn b0() -> Vec<Self> {
+ bneck_confs(1.0, 1.0)
+ }
+ pub fn b1() -> Vec<Self> {
+ bneck_confs(1.0, 1.1)
+ }
+ pub fn b2() -> Vec<Self> {
+ bneck_confs(1.1, 1.2)
+ }
+ pub fn b3() -> Vec<Self> {
+ bneck_confs(1.2, 1.4)
+ }
+ pub fn b4() -> Vec<Self> {
+ bneck_confs(1.4, 1.8)
+ }
+ pub fn b5() -> Vec<Self> {
+ bneck_confs(1.6, 2.2)
+ }
+ pub fn b6() -> Vec<Self> {
+ bneck_confs(1.8, 2.6)
+ }
+ pub fn b7() -> Vec<Self> {
+ bneck_confs(2.0, 3.1)
+ }
+}
+
+/// Conv2D with same padding.
+#[derive(Debug)]
+struct Conv2DSame {
+ conv2d: nn::Conv2d,
+ s: usize,
+ k: usize,
+}
+
+impl Conv2DSame {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ bias: bool,
+ ) -> Result<Self> {
+ let conv_config = nn::Conv2dConfig {
+ stride,
+ groups,
+ ..Default::default()
+ };
+ let conv2d = if bias {
+ nn::conv2d(i, o, k, conv_config, vb)?
+ } else {
+ nn::conv2d_no_bias(i, o, k, conv_config, vb)?
+ };
+ Ok(Self {
+ conv2d,
+ s: stride,
+ k,
+ })
+ }
+}
+
+impl Module for Conv2DSame {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let s = self.s;
+ let k = self.k;
+ let (_, _, ih, iw) = xs.dims4()?;
+ let oh = (ih + s - 1) / s;
+ let ow = (iw + s - 1) / s;
+ let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+ let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+ if pad_h > 0 || pad_w > 0 {
+ let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
+ let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
+ self.conv2d.forward(&xs)
+ } else {
+ self.conv2d.forward(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct ConvNormActivation {
+ conv2d: Conv2DSame,
+ bn2d: nn::BatchNorm,
+ activation: bool,
+}
+
+impl ConvNormActivation {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
+ let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
+ Ok(Self {
+ conv2d,
+ bn2d,
+ activation: true,
+ })
+ }
+
+ fn no_activation(self) -> Self {
+ Self {
+ activation: false,
+ ..self
+ }
+ }
+}
+
+impl Module for ConvNormActivation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.conv2d.forward(xs)?;
+ let xs = self.bn2d.forward(&xs)?;
+ if self.activation {
+ swish(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct SqueezeExcitation {
+ fc1: Conv2DSame,
+ fc2: Conv2DSame,
+}
+
+impl SqueezeExcitation {
+ fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
+ let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
+ let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for SqueezeExcitation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ // equivalent to adaptive_avg_pool2d([1, 1])
+ let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
+ let xs = self.fc1.forward(&xs)?;
+ let xs = swish(&xs)?;
+ let xs = self.fc2.forward(&xs)?;
+ let xs = nn::ops::sigmoid(&xs)?;
+ residual.broadcast_mul(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ expand_cna: Option<ConvNormActivation>,
+ depthwise_cna: ConvNormActivation,
+ squeeze_excitation: SqueezeExcitation,
+ project_cna: ConvNormActivation,
+ config: MBConvConfig,
+}
+
+impl MBConv {
+ fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
+ let vb = vb.pp("block");
+ let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
+ let expand_cna = if exp != c.input_channels {
+ Some(ConvNormActivation::new(
+ vb.pp("0"),
+ c.input_channels,
+ exp,
+ 1,
+ 1,
+ 1,
+ )?)
+ } else {
+ None
+ };
+ let start_index = if expand_cna.is_some() { 1 } else { 0 };
+ let depthwise_cna =
+ ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
+ let squeeze_channels = usize::max(1, c.input_channels / 4);
+ let squeeze_excitation =
+ SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
+ let project_cna =
+ ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
+ .no_activation();
+ Ok(Self {
+ expand_cna,
+ depthwise_cna,
+ squeeze_excitation,
+ project_cna,
+ config: c,
+ })
+ }
+}
+
+impl Module for MBConv {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let use_res_connect =
+ self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
+ let ys = match &self.expand_cna {
+ Some(expand_cna) => expand_cna.forward(xs)?,
+ None => xs.clone(),
+ };
+ let ys = self.depthwise_cna.forward(&ys)?;
+ let ys = self.squeeze_excitation.forward(&ys)?;
+ let ys = self.project_cna.forward(&ys)?;
+ if use_res_connect {
+ ys + xs
+ } else {
+ Ok(ys)
+ }
+ }
+}
+
+fn swish(s: &Tensor) -> Result<Tensor> {
+ s * nn::ops::sigmoid(s)?
+}
+
+#[derive(Debug)]
+pub struct EfficientNet {
+ init_cna: ConvNormActivation,
+ blocks: Vec<MBConv>,
+ final_cna: ConvNormActivation,
+ classifier: nn::Linear,
+}
+
+impl EfficientNet {
+ pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
+ let f_p = p.pp("features");
+ let first_in_c = configs[0].input_channels;
+ let last_out_c = configs.last().unwrap().out_channels;
+ let final_out_c = 4 * last_out_c;
+ let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
+ let nconfigs = configs.len();
+ let mut blocks = vec![];
+ for (index, cnf) in configs.into_iter().enumerate() {
+ let f_p = f_p.pp(index + 1);
+ for r_index in 0..cnf.num_layers {
+ let cnf = if r_index == 0 {
+ cnf
+ } else {
+ MBConvConfig {
+ input_channels: cnf.out_channels,
+ stride: 1,
+ ..cnf
+ }
+ };
+ blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
+ }
+ }
+ let final_cna =
+ ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
+ let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
+ Ok(Self {
+ init_cna,
+ blocks,
+ final_cna,
+ classifier,
+ })
+ }
+}
+
+impl Module for EfficientNet {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.init_cna.forward(xs)?;
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ let xs = self.final_cna.forward(&xs)?;
+ // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
+ let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
+ self.classifier.forward(&xs)
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 1b3dcf25..76e13b2a 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1,5 +1,9 @@
pub mod bert;
pub mod bigcode;
+pub mod dinov2;
+pub mod efficientnet;
pub mod falcon;
pub mod llama;
+pub mod quantized_llama;
+pub mod segment_anything;
pub mod whisper;
diff --git a/candle-examples/examples/quantized/model.rs b/candle-transformers/src/models/quantized_llama.rs
index da0bd0b0..da0bd0b0 100644
--- a/candle-examples/examples/quantized/model.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs
index 76cd15d0..0b313830 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-transformers/src/models/segment_anything/image_encoder.rs
@@ -100,8 +100,8 @@ impl candle::CustomOp3 for Add3 {
#[derive(Debug)]
struct Attention {
- qkv: crate::Linear,
- proj: crate::Linear,
+ qkv: super::Linear,
+ proj: super::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
@@ -124,8 +124,8 @@ impl Attention {
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
- let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
- let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
+ let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
@@ -251,7 +251,7 @@ struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
- mlp: crate::MlpBlock,
+ mlp: super::MlpBlock,
window_size: usize,
span: tracing::Span,
}
@@ -281,7 +281,7 @@ impl Block {
input_size_attn,
vb.pp("attn"),
)?;
- let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+ let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
@@ -375,9 +375,9 @@ pub struct ImageEncoderViT {
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
- neck_ln1: crate::LayerNorm2d,
+ neck_ln1: super::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
- neck_ln2: crate::LayerNorm2d,
+ neck_ln2: super::LayerNorm2d,
pos_embed: Option<Tensor>,
span: tracing::Span,
}
@@ -433,13 +433,13 @@ impl ImageEncoderViT {
Default::default(),
vb.pp("neck.0"),
)?;
- let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
+ let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
- let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
+ let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs
index c02b44a7..2a91cd44 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs
@@ -1,11 +1,11 @@
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
-use crate::model_transformer::TwoWayTransformer;
+use super::transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
- layers: Vec<crate::Linear>,
+ layers: Vec<super::Linear>,
sigmoid_output: bool,
span: tracing::Span,
}
@@ -28,7 +28,7 @@ impl MlpMaskDecoder {
} else {
hidden_dim
};
- let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
+ let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
@@ -64,7 +64,7 @@ pub struct MaskDecoder {
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
- output_upscaling_ln: crate::LayerNorm2d,
+ output_upscaling_ln: super::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
@@ -104,7 +104,7 @@ impl MaskDecoder {
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
- crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+ super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
new file mode 100644
index 00000000..c29db70a
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -0,0 +1,100 @@
+use candle::{Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+pub mod image_encoder;
+pub mod mask_decoder;
+pub mod prompt_encoder;
+pub mod sam;
+pub mod tiny_vit;
+pub mod transformer;
+
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+#[derive(Debug)]
+pub struct LayerNorm2d {
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ eps: f64,
+}
+
+impl LayerNorm2d {
+ pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(num_channels, "weight")?;
+ let bias = vb.get(num_channels, "bias")?;
+ Ok(Self {
+ weight,
+ bias,
+ num_channels,
+ eps,
+ })
+ }
+}
+
+impl Module for LayerNorm2d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let u = xs.mean_keepdim(1)?;
+ let xs = xs.broadcast_sub(&u)?;
+ let s = xs.sqr()?.mean_keepdim(1)?;
+ let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+ xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+ .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+ }
+}
+
+#[derive(Debug)]
+pub struct MlpBlock {
+ lin1: Linear,
+ lin2: Linear,
+ activation: candle_nn::Activation,
+ span: tracing::Span,
+}
+
+impl MlpBlock {
+ pub fn new(
+ embedding_dim: usize,
+ mlp_dim: usize,
+ activation: candle_nn::Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
+ Ok(Self {
+ lin1,
+ lin2,
+ activation,
+ span,
+ })
+ }
+}
+
+impl Module for MlpBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
index 7bbe8419..9d0074b1 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -56,9 +56,9 @@ pub struct PromptEncoder {
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
- mask_downscaling_ln1: crate::LayerNorm2d,
+ mask_downscaling_ln1: super::LayerNorm2d,
mask_downscaling_conv2: candle_nn::Conv2d,
- mask_downscaling_ln2: crate::LayerNorm2d,
+ mask_downscaling_ln2: super::LayerNorm2d,
mask_downscaling_conv3: candle_nn::Conv2d,
no_mask_embed: candle_nn::Embedding,
image_embedding_size: (usize, usize),
@@ -100,9 +100,9 @@ impl PromptEncoder {
vb.pp("mask_downscaling.6"),
)?;
let mask_downscaling_ln1 =
- crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
let mask_downscaling_ln2 =
- crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
let vb_e = vb.pp("point_embeddings");
for i in 0..num_points_embeddings {
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index b1a81af6..c40473e3 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -1,10 +1,10 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
-use crate::model_image_encoder::ImageEncoderViT;
-use crate::model_mask_decoder::MaskDecoder;
-use crate::model_prompt_encoder::PromptEncoder;
-use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
+use super::image_encoder::ImageEncoderViT;
+use super::mask_decoder::MaskDecoder;
+use super::prompt_encoder::PromptEncoder;
+use super::tiny_vit::{tiny_vit_5m, TinyViT};
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
@@ -186,7 +186,7 @@ impl Sam {
img: &Tensor,
cb: CropBox,
point_grids: &[(f64, f64)],
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
@@ -259,7 +259,7 @@ impl Sam {
let min_max_x = min_max_indexes(&low_res_mask_per_x);
let min_max_y = min_max_indexes(&low_res_mask_per_y);
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
- let bbox = candle_examples::object_detection::Bbox {
+ let bbox = crate::object_detection::Bbox {
xmin: x0 as f32,
ymin: y0 as f32,
xmax: x1 as f32,
@@ -277,7 +277,7 @@ impl Sam {
let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
- candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
+ crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
// TODO: Return to the original image frame.
Ok(bboxes.remove(0))
@@ -290,7 +290,7 @@ impl Sam {
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
diff --git a/candle-examples/examples/segment-anything/model_tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
index ff076773..cd2936ab 100644
--- a/candle-examples/examples/segment-anything/model_tiny_vit.rs
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -215,16 +215,16 @@ impl Module for ConvLayer {
#[derive(Debug)]
struct Mlp {
norm: candle_nn::LayerNorm,
- fc1: crate::Linear,
- fc2: crate::Linear,
+ fc1: super::Linear,
+ fc2: super::Linear,
span: tracing::Span,
}
impl Mlp {
fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
- let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
- let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
+ let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
+ let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
Ok(Self {
norm,
@@ -248,8 +248,8 @@ impl Module for Mlp {
#[derive(Debug)]
struct Attention {
norm: candle_nn::LayerNorm,
- qkv: crate::Linear,
- proj: crate::Linear,
+ qkv: super::Linear,
+ proj: super::Linear,
ab: Tensor,
key_dim: usize,
num_heads: usize,
@@ -275,8 +275,8 @@ impl Attention {
let nh_kd = key_dim * num_heads;
let h = dh + nh_kd * 2;
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
- let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
- let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
+ let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
+ let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
let points = (0..resolution.0)
.flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@@ -526,9 +526,9 @@ pub struct TinyViT {
// norm_head: candle_nn::LayerNorm,
// head: candle_nn::Linear,
neck_conv1: candle_nn::Conv2d,
- neck_ln1: crate::LayerNorm2d,
+ neck_ln1: super::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
- neck_ln2: crate::LayerNorm2d,
+ neck_ln2: super::LayerNorm2d,
span: tracing::Span,
span_neck: tracing::Span,
}
@@ -578,13 +578,13 @@ impl TinyViT {
// let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
- let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
+ let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
- let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
+ let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs
index e12aac08..80efb38c 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-transformers/src/models/segment_anything/transformer.rs
@@ -68,7 +68,7 @@ struct TwoWayAttentionBlock {
norm1: LayerNorm,
cross_attn_token_to_image: Attention,
norm2: LayerNorm,
- mlp: crate::MlpBlock,
+ mlp: super::MlpBlock,
norm3: LayerNorm,
norm4: LayerNorm,
cross_attn_image_to_token: Attention,
@@ -100,7 +100,7 @@ impl TwoWayAttentionBlock {
2,
vb.pp("cross_attn_image_to_token"),
)?;
- let mlp = crate::MlpBlock::new(
+ let mlp = super::MlpBlock::new(
embedding_dim,
mlp_dim,
candle_nn::Activation::Relu,
diff --git a/candle-examples/src/object_detection.rs b/candle-transformers/src/object_detection.rs
index ce579316..ce579316 100644
--- a/candle-examples/src/object_detection.rs
+++ b/candle-transformers/src/object_detection.rs