Diffstat (limited to 'candle-examples')
 -rw-r--r--  candle-examples/examples/mobilenetv4/README.md |  18
 -rw-r--r--  candle-examples/examples/mobilenetv4/main.rs   | 106
 -rw-r--r--  candle-examples/src/imagenet.rs                |  30
3 files changed, 137 insertions, 17 deletions
diff --git a/candle-examples/examples/mobilenetv4/README.md b/candle-examples/examples/mobilenetv4/README.md
new file mode 100644
index 00000000..c8356466
--- /dev/null
+++ b/candle-examples/examples/mobilenetv4/README.md
@@ -0,0 +1,18 @@
+# candle-mobilenetv4
+
+[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
+This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
+The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
+
+## Running an example
+
+```
+$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium
+loaded image Tensor[dims 3, 256, 256; f32]
+model built
+unicycle, monocycle     : 20.18%
+mountain bike, all-terrain bike, off-roader: 19.77%
+bicycle-built-for-two, tandem bicycle, tandem: 15.91%
+crash helmet            : 1.15%
+tricycle, trike, velocipede: 0.67%
+```
diff --git a/candle-examples/examples/mobilenetv4/main.rs b/candle-examples/examples/mobilenetv4/main.rs
new file mode 100644
index 00000000..26c0dad9
--- /dev/null
+++ b/candle-examples/examples/mobilenetv4/main.rs
@@ -0,0 +1,106 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::mobilenetv4;
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    Small,
+    Medium,
+    Large,
+    HybridMedium,
+    HybridLarge,
+}
+
+impl Which {
+    fn model_filename(&self) -> String {
+        let name = match self {
+            Self::Small => "conv_small.e2400_r224",
+            Self::Medium => "conv_medium.e500_r256",
+            Self::HybridMedium => "hybrid_medium.ix_e550_r256",
+            Self::Large => "conv_large.e600_r384",
+            Self::HybridLarge => "hybrid_large.ix_e600_r384",
+        };
+        format!("timm/mobilenetv4_{}_in1k", name)
+    }
+
+    fn resolution(&self) -> u32 {
+        match self {
+            Self::Small => 224,
+            Self::Medium => 256,
+            Self::HybridMedium => 256,
+            Self::Large => 384,
+            Self::HybridLarge => 384,
+        }
+    }
+    fn config(&self) -> mobilenetv4::Config {
+        match self {
+            Self::Small => mobilenetv4::Config::small(),
+            Self::Medium => mobilenetv4::Config::medium(),
+            Self::HybridMedium => mobilenetv4::Config::hybrid_medium(),
+            Self::Large => mobilenetv4::Config::large(),
+            Self::HybridLarge => mobilenetv4::Config::hybrid_large(),
+        }
+    }
+}
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(value_enum, long, default_value_t=Which::Small)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
+        .to_device(&device)?;
+    println!("loaded image {image:?}");
+
+    let model_file = match args.model {
+        None => {
+            let model_name = args.which.model_filename();
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model(model_name);
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?;
+    println!("model built");
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index 781dcd4f..6b079870 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -1,15 +1,16 @@
 use candle::{Device, Result, Tensor};
 
-/// Loads an image from disk using the image crate, this returns a tensor with shape
-/// (3, 224, 224). imagenet normalization is applied.
-pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+/// Loads an image from disk using the image crate at the requested resolution.
+// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
     let img = image::io::Reader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?
-        .resize_to_fill(224, 224, image::imageops::FilterType::Triangle);
+        .resize_to_fill(res, res, image::imageops::FilterType::Triangle);
     let img = img.to_rgb8();
     let data = img.into_raw();
-    let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+    let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?;
     let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
     let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
     (data.to_dtype(candle::DType::F32)? / 255.)?
@@ -18,21 +19,16 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
 }
 
 /// Loads an image from disk using the image crate, this returns a tensor with shape
+/// (3, 224, 224). imagenet normalization is applied.
+pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+    load_image(p, 224)
+}
+
+/// Loads an image from disk using the image crate, this returns a tensor with shape
 /// (3, 518, 518). imagenet normalization is applied.
 /// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens).
 pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
-        .decode()
-        .map_err(candle::Error::wrap)?
-        .resize_to_fill(518, 518, image::imageops::FilterType::Triangle);
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?;
-    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
-    let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
-    (data.to_dtype(candle::DType::F32)? / 255.)?
-        .broadcast_sub(&mean)?
-        .broadcast_div(&std)
+    load_image(p, 518)
+}
 
 pub const CLASS_COUNT: i64 = 1000;
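The imagenet.rs change above folds the fixed-resolution loaders into a single `load_image(path, res)` helper, with `load_image224` and `load_image518` kept as thin wrappers so existing examples keep compiling. A minimal sketch of how a caller would use the new helper, assuming the candle-examples crate from this commit is available as a dependency (the image path and the standalone `main` below are illustrative, not part of the commit):

```rust
fn main() -> candle::Result<()> {
    // Resolution-aware call: returns a (3, 256, 256) f32 tensor with ImageNet
    // mean/std normalization applied, matching the `medium` MobileNetV4 variant.
    let img = candle_examples::imagenet::load_image("bike.jpg", 256)?;

    // Pre-existing callers are untouched: load_image224 now forwards to load_image(p, 224).
    let img_224 = candle_examples::imagenet::load_image224("bike.jpg")?;

    println!("{img:?} {img_224:?}");
    Ok(())
}
```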