diff options
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | candle-examples/examples/mobilenetv4/README.md | 18 | ||||
-rw-r--r-- | candle-examples/examples/mobilenetv4/main.rs | 106 | ||||
-rw-r--r-- | candle-examples/src/imagenet.rs | 30 | ||||
-rw-r--r-- | candle-transformers/src/models/mobilenetv4.rs | 800 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
6 files changed, 939 insertions, 18 deletions
@@ -236,7 +236,7 @@ If you have an addition to this list, please submit a pull request. - MetaVoice-1B, text-to-speech model. - Computer Vision Models. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - ConvNeXTv2, MobileOne, EfficientVit (MSRA). + ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4. - yolo-v3, yolo-v8. - Segment-Anything Model (SAM). - SegFormer. diff --git a/candle-examples/examples/mobilenetv4/README.md b/candle-examples/examples/mobilenetv4/README.md new file mode 100644 index 00000000..c8356466 --- /dev/null +++ b/candle-examples/examples/mobilenetv4/README.md @@ -0,0 +1,18 @@ +# candle-mobilenetv4 + +[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518) +This candle implementation uses pre-trained MobileNetV4 models from timm for inference. +The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes. + +## Running an example + +``` +$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium +loaded image Tensor[dims 3, 256, 256; f32] +model built +unicycle, monocycle : 20.18% +mountain bike, all-terrain bike, off-roader: 19.77% +bicycle-built-for-two, tandem bicycle, tandem: 15.91% +crash helmet : 1.15% +tricycle, trike, velocipede: 0.67% +``` diff --git a/candle-examples/examples/mobilenetv4/main.rs b/candle-examples/examples/mobilenetv4/main.rs new file mode 100644 index 00000000..26c0dad9 --- /dev/null +++ b/candle-examples/examples/mobilenetv4/main.rs @@ -0,0 +1,106 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::mobilenetv4; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + Small, + Medium, + Large, + HybridMedium, + HybridLarge, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::Small => "conv_small.e2400_r224", + Self::Medium => "conv_medium.e500_r256", + Self::HybridMedium => "hybrid_medium.ix_e550_r256", + Self::Large => "conv_large.e600_r384", + Self::HybridLarge => "hybrid_large.ix_e600_r384", + }; + format!("timm/mobilenetv4_{}_in1k", name) + } + + fn resolution(&self) -> u32 { + match self { + Self::Small => 224, + Self::Medium => 256, + Self::HybridMedium => 256, + Self::Large => 384, + Self::HybridLarge => 384, + } + } + fn config(&self) -> mobilenetv4::Config { + match self { + Self::Small => mobilenetv4::Config::small(), + Self::Medium => mobilenetv4::Config::medium(), + Self::HybridMedium => mobilenetv4::Config::hybrid_medium(), + Self::Large => mobilenetv4::Config::large(), + Self::HybridLarge => mobilenetv4::Config::hybrid_large(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::Small)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())? + .to_device(&device)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::<f32>()?; + let mut prs = prs.iter().enumerate().collect::<Vec<_>>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index 781dcd4f..6b079870 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -1,15 +1,16 @@ use candle::{Device, Result, Tensor}; -/// Loads an image from disk using the image crate, this returns a tensor with shape -/// (3, 224, 224). imagenet normalization is applied. -pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { +/// Loads an image from disk using the image crate at the requested resolution. +// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> { let img = image::io::Reader::open(p)? .decode() .map_err(candle::Error::wrap)? - .resize_to_fill(224, 224, image::imageops::FilterType::Triangle); + .resize_to_fill(res, res, image::imageops::FilterType::Triangle); let img = img.to_rgb8(); let data = img.into_raw(); - let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)? + .permute((2, 0, 1))?; let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? @@ -18,21 +19,16 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { } /// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 224, 224). imagenet normalization is applied. +pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { + load_image(p, 224) +} + +/// Loads an image from disk using the image crate, this returns a tensor with shape /// (3, 518, 518). imagenet normalization is applied. /// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { - let img = image::io::Reader::open(p)? - .decode() - .map_err(candle::Error::wrap)? - .resize_to_fill(518, 518, image::imageops::FilterType::Triangle); - let img = img.to_rgb8(); - let data = img.into_raw(); - let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?; - let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; - (data.to_dtype(candle::DType::F32)? / 255.)? - .broadcast_sub(&mean)? - .broadcast_div(&std) + load_image(p, 518) } pub const CLASS_COUNT: i64 = 1000; diff --git a/candle-transformers/src/models/mobilenetv4.rs b/candle-transformers/src/models/mobilenetv4.rs new file mode 100644 index 00000000..7cbae7c3 --- /dev/null +++ b/candle-transformers/src/models/mobilenetv4.rs @@ -0,0 +1,800 @@ +//! MobileNet-v4 inference implementation based on timm. +//! +//! See "MobileNetV4 - Universal Models for the Mobile Ecosystem" +//! https://arxiv.org/abs/2404.10518 +//! +//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py + +use candle::{Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d_no_bias, linear, ops::softmax, Activation, Conv2dConfig, Func, VarBuilder, +}; + +#[derive(Clone, Debug)] +enum BlockType { + Convolutional { + out_channels: usize, + kernel: usize, + stride: usize, + }, + UniversalBottleneck { + out_channels: usize, + start_kernel: usize, + mid_kernel: usize, + stride: usize, + expand: usize, + }, + EdgeResidual { + out_channels: usize, + kernel: usize, + stride: usize, + expand: usize, + }, + Attention { + out_channels: usize, + heads: usize, + kernel: usize, + stride: usize, + kv_dim: usize, + kv_stride: usize, + }, +} + +#[derive(Clone, Debug)] +pub struct Config { + stem_dim: usize, + activation: Activation, + stages: [Vec<BlockType>; 5], +} + +#[rustfmt::skip] +impl Config { + pub fn small() -> Self { + Self { + stem_dim: 32, + activation: Activation::Relu, + stages: [ + vec![ + BlockType::Convolutional { out_channels: 32, kernel: 3, stride: 2}, + BlockType::Convolutional { out_channels: 32, kernel: 1, stride: 1}, + ], + vec![ + BlockType::Convolutional { out_channels: 96, kernel: 3, stride: 2}, + BlockType::Convolutional { out_channels: 64, kernel: 1, stride: 1}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 3}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 3, mid_kernel: 3, stride: 2, expand: 6}, + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 3}, + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4}, + ], + vec![ + BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1}, + ], + ], + } + } + + pub fn medium() -> Self { + Self { + stem_dim: 32, + activation: Activation::Relu, + stages: [ + vec![ + BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 2}, + + ], + vec![ + BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1}, + ], + ], + } + } + + pub fn hybrid_medium() -> Self { + Self { + stem_dim: 32, + activation: Activation::Relu, + stages: [ + vec![ + BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + ], + + vec![ + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1}, + ], + ], + } + } + + pub fn large() -> Self { + Self { + stem_dim: 24, + activation: Activation::Relu, + stages: [ + vec![ + BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1}, + ], + ], + } + } + + pub fn hybrid_large() -> Self { + Self { + stem_dim: 24, + activation: Activation::Gelu, + stages: [ + vec![ + BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + ], + vec![ + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48}, + BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4}, + ], + + vec![ + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64}, + BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4}, + ], + vec![ + BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1}, + ], + ], + } + } +} + +fn depthwise_conv( + channels: usize, + kernel: usize, + stride: usize, + padding: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + stride, + padding, + groups: channels, + ..Default::default() + }; + + let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?; + let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?; + + Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false))) +} + +fn pointwise_conv( + in_channels: usize, + out_channels: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + + let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?; + let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?; + + Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false))) +} + +//Universal block that uses two pointwise convolutions and all combinations of two depthwise convolutions. +#[allow(clippy::too_many_arguments)] +fn universal_inverted_bottleneck_block( + cfg: &Config, + in_channels: usize, + out_channels: usize, + expand: usize, + start_kernel: usize, + mid_kernel: usize, + stride: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let act = cfg.activation; + let skip_connection = (in_channels == out_channels) && (stride == 1); + + let dw_start_stride = if mid_kernel > 0 { 1 } else { stride }; + let dw_start = depthwise_conv( + in_channels, + start_kernel, + dw_start_stride, + start_kernel / 2, + vb.pp("dw_start"), + ); + let pw_exp = pointwise_conv(in_channels, in_channels * expand, vb.pp("pw_exp"))?; + let dw_mid = depthwise_conv( + in_channels * expand, + mid_kernel, + stride, + mid_kernel / 2, + vb.pp("dw_mid"), + ); + let pw_proj = pointwise_conv(in_channels * expand, out_channels, vb.pp("pw_proj"))?; + + let gamma = vb.get(out_channels, "layer_scale.gamma"); + + Ok(Func::new(move |xs| { + let residual = xs.clone(); + + let mut xs = xs.clone(); + + if let Ok(f) = &dw_start { + xs = xs.apply(f)?; + } + + xs = xs.apply(&pw_exp)?.apply(&act)?; + + if let Ok(f) = &dw_mid { + xs = xs.apply(f)?.apply(&act)?; + } + + xs = xs.apply(&pw_proj)?; + + if let Ok(g) = &gamma { + xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?; + }; + + if skip_connection { + xs = (xs + residual)?; + } + + Ok(xs) + })) +} + +// Convolutional block including norm and activation. +fn conv_block( + cfg: &Config, + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + stride, + padding: kernel / 2, + ..Default::default() + }; + + let act = cfg.activation; + let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?; + let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?; + + Ok(Func::new(move |xs| { + xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act) + })) +} + +fn edge_residual_block( + cfg: &Config, + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + expand: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv_exp_cfg = Conv2dConfig { + stride, + padding: kernel / 2, + ..Default::default() + }; + + let conv_pwl_cfg = Conv2dConfig { + ..Default::default() + }; + + let act = cfg.activation; + let mid_channels = in_channels * expand; + let conv_exp = conv2d_no_bias( + in_channels, + mid_channels, + kernel, + conv_exp_cfg, + vb.pp("conv_exp"), + )?; + let bn1 = batch_norm(mid_channels, 1e-5, vb.pp("bn1"))?; + + let conv_pwl = conv2d_no_bias( + mid_channels, + out_channels, + 1, + conv_pwl_cfg, + vb.pp("conv_pwl"), + )?; + let bn2 = batch_norm(out_channels, 1e-5, vb.pp("bn2"))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&conv_exp)? + .apply_t(&bn1, false)? + .apply(&act)? + .apply(&conv_pwl)? + .apply_t(&bn2, false)?; + + Ok(xs) + })) +} + +fn reshape_kv(t: &Tensor) -> Result<Tensor> { + let d = t.dims4()?; + let t = t + .reshape((d.0, d.1, ()))? + .transpose(1, 2)? + .unsqueeze(1)? + .contiguous()?; + Ok(t) +} + +fn reshape_query(t: &Tensor, heads: usize, kv_dim: usize) -> Result<Tensor> { + let d = t.dims4()?; + + let t = t + .reshape((d.0, heads, kv_dim, ()))? + .transpose(D::Minus1, D::Minus2)? + .contiguous()?; + Ok(t) +} + +fn reshape_output(t: &Tensor, heads: usize, h: usize, w: usize) -> Result<Tensor> { + let d = t.dims4()?; + let t = t.transpose(1, 2)?; + let t = t + .reshape((d.0, h, w, d.3 * heads))? + .permute((0, 3, 1, 2))? + .contiguous()?; + Ok(t) +} + +// Mobile multi-query attention +#[allow(clippy::too_many_arguments)] +fn mqa_block( + in_channels: usize, + out_channels: usize, + heads: usize, + kernel: usize, + stride: usize, + kv_dim: usize, + kv_stride: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let down_conv2d_cfg = Conv2dConfig { + stride: kv_stride, + padding: kernel / 2, + groups: in_channels, + ..Default::default() + }; + + let proj_conv2d_cfg = Conv2dConfig { + stride, + ..Default::default() + }; + + let skip_connection = (in_channels == out_channels) && (stride == 1); + let gamma = vb.get(out_channels, "layer_scale.gamma"); + let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?; + let scale = (kv_dim as f64).powf(-0.5); + + let vb = vb.pp("attn"); + + let query_proj = conv2d_no_bias( + out_channels, + kv_dim * heads, + 1, + proj_conv2d_cfg, + vb.pp("query.proj"), + )?; + + let key_down_conv = conv2d_no_bias( + in_channels, + out_channels, + kernel, + down_conv2d_cfg, + vb.pp("key.down_conv"), + ); + let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm")); + + let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?; + + let value_down_conv = conv2d_no_bias( + in_channels, + out_channels, + kernel, + down_conv2d_cfg, + vb.pp("value.down_conv"), + ); + + let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm")); + let value_proj = conv2d_no_bias( + out_channels, + kv_dim, + 1, + proj_conv2d_cfg, + vb.pp("value.proj"), + )?; + + let output_proj = conv2d_no_bias( + kv_dim * heads, + out_channels, + 1, + proj_conv2d_cfg, + vb.pp("output.proj"), + )?; + + Ok(Func::new(move |xs| { + let (_, _, h, w) = xs.dims4()?; + + let residual = xs.clone(); + + let xs = xs.apply_t(&norm, false)?; + + // Query + let q = xs.apply(&query_proj)?; + + let q = reshape_query(&q, heads, kv_dim)?; + let q = (q * scale)?; + + // Keys + let mut k = xs.clone(); + + if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) { + k = k.apply(kd)?.apply_t(n, false)?; + } + + let k = k.apply(&key_proj)?; + + let k = reshape_kv(&k)?; + + // Value + let mut v = xs.clone(); + + if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) { + v = v.apply(vd)?; + v = v.apply_t(n, false)?; + } + + let v = v.apply(&value_proj)?; + let v = reshape_kv(&v)?; + + let attn = q.broadcast_matmul(&(k.transpose(D::Minus2, D::Minus1)?))?; + let attn = softmax(&attn, D::Minus1)?; + let o = attn.broadcast_matmul(&v)?; + + let o = reshape_output(&o, heads, h, w)?; + + let mut xs = o.apply(&output_proj)?; + + // Layer scale + + if let Ok(g) = &gamma { + xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?; + }; + + if skip_connection { + xs = (xs + residual)?; + } + Ok(xs) + })) +} + +// Stem. +fn mobilenetv4_stem(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + stride: 2, + padding: 1, + ..Default::default() + }; + + let act = cfg.activation; + let out_channels = cfg.stem_dim; + let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?; + let conv = conv2d_no_bias(3, out_channels, 3, conv2d_cfg, vb.pp("conv_stem"))?; + + Ok(Func::new(move |xs| { + let xs = xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)?; + Ok(xs) + })) +} + +// The blocks in all the 5 stages of the model. +fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> { + let mut in_channels = cfg.stem_dim; + let mut blocks = Vec::new(); + + for stage in 0..5 { + let nblocks = cfg.stages[stage].len(); + + for block in 0..nblocks { + match cfg.stages[stage][block] { + BlockType::Convolutional { + out_channels, + kernel, + stride, + } => { + blocks.push(conv_block( + cfg, + in_channels, + out_channels, + kernel, + stride, + vb.pp(format!("{stage}.{block}")), + )?); + in_channels = out_channels; + } + + BlockType::EdgeResidual { + out_channels, + kernel, + stride, + expand, + } => { + blocks.push(edge_residual_block( + cfg, + in_channels, + out_channels, + kernel, + stride, + expand, + vb.pp(format!("{stage}.{block}")), + )?); + in_channels = out_channels; + } + + BlockType::UniversalBottleneck { + out_channels, + start_kernel, + mid_kernel, + stride, + expand, + } => { + blocks.push(universal_inverted_bottleneck_block( + cfg, + in_channels, + out_channels, + expand, + start_kernel, + mid_kernel, + stride, + vb.pp(format!("{stage}.{block}")), + )?); + in_channels = out_channels; + } + + BlockType::Attention { + out_channels, + heads, + kernel, + stride, + kv_dim, + kv_stride, + } => { + blocks.push(mqa_block( + in_channels, + out_channels, + heads, + kernel, + stride, + kv_dim, + kv_stride, + vb.pp(format!("{stage}.{block}")), + )?); + in_channels = out_channels; + } + } + } + } + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for block in blocks.iter() { + xs = xs.apply(block)? + } + Ok(xs) + })) +} + +// Classification head. +fn mobilenetv4_head( + cfg: &Config, + outputs: usize, + nclasses: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + + let act = cfg.activation; + let conv = conv2d_no_bias(960, outputs, 1, conv2d_cfg, vb.pp("conv_head"))?; + let norm = batch_norm(outputs, 1e-5, vb.pp("norm_head"))?; + let cls = linear(outputs, nclasses, vb.pp("classifier"))?; + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + xs = xs.apply(&conv)?; + xs = xs.apply_t(&norm, false)?.apply(&act)?; + xs = xs.flatten_from(1)?; + xs = xs.apply(&cls)?; + Ok(xs) + })) +} + +// Build a mobilenetv4 model for a given configuration. +fn mobilenetv4_model( + cfg: &Config, + nclasses: Option<usize>, + vb: VarBuilder, +) -> Result<Func<'static>> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let outputs = 1280; + let head = mobilenetv4_head(cfg, outputs, nclasses, vb.clone())?; + Some(head) + } + }; + + let stem = mobilenetv4_stem(cfg, vb.clone())?; + + let blocks = mobilenetv4_blocks(cfg, vb.pp("blocks"))?; + + Ok(Func::new(move |xs| { + let xs = xs.apply(&stem)?.apply(&blocks)?; + let xs = xs.mean_keepdim(D::Minus1)?.mean_keepdim(D::Minus2)?; + match &cls { + None => Ok(xs), + Some(cls) => xs.apply(cls), + } + })) +} + +pub fn mobilenetv4(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> { + mobilenetv4_model(cfg, Some(nclasses), vb) +} + +pub fn mobilenetv4_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> { + mobilenetv4_model(cfg, None, vb) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index f5859a99..86a0ec08 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,6 +28,7 @@ pub mod metavoice; pub mod mistral; pub mod mixformer; pub mod mixtral; +pub mod mobilenetv4; pub mod mobileone; pub mod moondream; pub mod mpt; |