author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-07 19:22:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-07 19:22:45 +0100 |
commit | 7396b8ed1a5394c58fcc772e5f6e6038577505b8 (patch) | |
tree | f7ce0cf676705e800093c05884ce5fc7443b7b0b /candle-examples/examples/segment-anything/main.rs | |
parent | 7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (diff) | |
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some size tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 106 |
1 file changed, 63 insertions, 43 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 368b5a33..a2722270 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -15,7 +15,7 @@ pub mod model_sam;
 pub mod model_transformer;
 
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
 use clap::Parser;
 
 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
@@ -27,64 +27,73 @@ pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Resu
 }
 
 #[derive(Debug)]
+pub struct LayerNorm2d {
+    weight: Tensor,
+    bias: Tensor,
+    num_channels: usize,
+    eps: f64,
+}
+
+impl LayerNorm2d {
+    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let weight = vb.get(num_channels, "weight")?;
+        let bias = vb.get(num_channels, "bias")?;
+        Ok(Self {
+            weight,
+            bias,
+            num_channels,
+            eps,
+        })
+    }
+}
+
+impl Module for LayerNorm2d {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let u = xs.mean_keepdim(1)?;
+        let xs = xs.broadcast_sub(&u)?;
+        let s = xs.sqr()?.mean_keepdim(1)?;
+        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+        xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+            .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+    }
+}
+
+#[derive(Debug)]
 pub struct MlpBlock {
     lin1: Linear,
     lin2: Linear,
+    activation: candle_nn::Activation,
 }
 
 impl MlpBlock {
-    pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(
+        embedding_dim: usize,
+        mlp_dim: usize,
+        activation: candle_nn::Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
         let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
-        Ok(Self { lin1, lin2 })
+        Ok(Self {
+            lin1,
+            lin2,
+            activation,
+        })
     }
 }
 
 impl Module for MlpBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+        xs.apply(&self.lin1)?
+            .apply(&self.activation)?
+            .apply(&self.lin2)
     }
 }
 
-/*
-    fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
-        let npatch = xs.dim(1)? - 1;
-        let n = self.pos_embed.dim(1)? - 1;
-        let sqrt_n = (n as f64).sqrt();
-        if npatch == n && w == h {
-            return Ok(xs.clone());
-        }
-        let class_pos_embed = self.pos_embed.i((.., ..1))?;
-        let patch_pos_embed = self.pos_embed.i((.., 1..))?;
-        let dim = xs.dim(D::Minus1)?;
-        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
-        let patch_pos_embed = patch_pos_embed
-            .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
-            .transpose(2, 3)?
-            .transpose(1, 2)?;
-        // This uses bicubic interpolation in the original implementation.
-        let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
-        let el_count = patch_pos_embed.shape().elem_count();
-        let patch_pos_embed =
-            patch_pos_embed
-                .transpose(1, 2)?
-                .transpose(2, 3)?
-                .reshape((1, el_count / dim, dim))?;
-        Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
-    }
-
-    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
-        let (_b, _nc, w, h) = xs.dims4()?;
-        let xs = self.patch_embed.forward(xs)?;
-        let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
-        &xs + &self.interpolate_pos_encoding(&xs, w, h)?
-    }
-*/
-
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: String,
+    model: Option<String>,
 
     #[arg(long)]
     image: String,
@@ -99,13 +108,24 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
     println!("loaded image {image:?}");
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let model = match args.model {
+        Some(model) => std::path::PathBuf::from(model),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("lmz/candle-sam".to_string());
+            api.get("sam_vit_b_01ec64.safetensors")?
+        }
+    };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
    let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-    let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+    let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+    let (mask, iou_predictions) = sam.forward(&image, false)?;
+    println!("mask: {mask:?}");
+    println!("iou_predictions: {iou_predictions:?}");
 
     Ok(())
 }
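The new `LayerNorm2d` normalizes across the channel dimension (dim 1) of an NCHW tensor rather than the last dimension, which is why the generic `candle_nn::LayerNorm` import is dropped. A minimal sketch of exercising it in isolation, assuming the `LayerNorm2d` definition from the diff above is in scope; the zero-initialized `VarBuilder` and the `check_layer_norm_2d` helper are placeholders for illustration only, not part of the commit:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};

fn check_layer_norm_2d() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder zero-filled weights; the example itself loads real weights
    // from a safetensors file via VarBuilder::from_safetensors.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let ln = LayerNorm2d::new(4, 1e-6, vb)?;
    // LayerNorm2d normalizes over the channel axis (dim 1) of an NCHW tensor,
    // then rescales with per-channel weight and bias reshaped to (1, C, 1, 1).
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &device)?;
    let ys = ln.forward(&xs)?;
    println!("{:?}", ys.dims()); // [1, 4, 8, 8]
    Ok(())
}
```

As for running the updated example: per the diff, `--model` is now optional; when omitted, the weights file `sam_vit_b_01ec64.safetensors` is fetched from the `lmz/candle-sam` hub repository, while `--image` still points at the input image.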