author    Laurent Mazare <laurent.mazare@gmail.com>    2023-09-07 19:22:45 +0100
committer GitHub <noreply@github.com>    2023-09-07 19:22:45 +0100
commit    7396b8ed1a5394c58fcc772e5f6e6038577505b8 (patch)
tree      f7ce0cf676705e800093c05884ce5fc7443b7b0b /candle-examples/examples/segment-anything/main.rs
parent    7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (diff)
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
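For reference, the LayerNorm2d module introduced in this diff normalizes each spatial position across the channel dimension, rather than over the last axis as the stock LayerNorm does. In math form (notation mine, not from the commit), for an input x of shape (B, C, H, W):

$$ \mu_{b,h,w} = \frac{1}{C}\sum_{c} x_{b,c,h,w}, \qquad \sigma^2_{b,h,w} = \frac{1}{C}\sum_{c} \left(x_{b,c,h,w} - \mu_{b,h,w}\right)^2 $$

$$ y_{b,c,h,w} = w_c \cdot \frac{x_{b,c,h,w} - \mu_{b,h,w}}{\sqrt{\sigma^2_{b,h,w} + \epsilon}} + b_c $$

This corresponds to the mean_keepdim(1) / broadcast_* sequence in the forward implementation below, with the per-channel weight and bias reshaped to (1, C, 1, 1) for broadcasting.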
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--  candle-examples/examples/segment-anything/main.rs | 106
1 file changed, 63 insertions(+), 43 deletions(-)
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 368b5a33..a2722270 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -15,7 +15,7 @@ pub mod model_sam;
pub mod model_transformer;
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
@@ -27,64 +27,73 @@ pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
}
#[derive(Debug)]
+pub struct LayerNorm2d {
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ eps: f64,
+}
+
+impl LayerNorm2d {
+ pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(num_channels, "weight")?;
+ let bias = vb.get(num_channels, "bias")?;
+ Ok(Self {
+ weight,
+ bias,
+ num_channels,
+ eps,
+ })
+ }
+}
+
+impl Module for LayerNorm2d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let u = xs.mean_keepdim(1)?;
+ let xs = xs.broadcast_sub(&u)?;
+ let s = xs.sqr()?.mean_keepdim(1)?;
+ let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+ xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+ .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+ }
+}
+
+#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
+ activation: candle_nn::Activation,
}
impl MlpBlock {
- pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result<Self> {
+ pub fn new(
+ embedding_dim: usize,
+ mlp_dim: usize,
+ activation: candle_nn::Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
- Ok(Self { lin1, lin2 })
+ Ok(Self {
+ lin1,
+ lin2,
+ activation,
+ })
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+ xs.apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
}
}
-/*
- fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
- let npatch = xs.dim(1)? - 1;
- let n = self.pos_embed.dim(1)? - 1;
- let sqrt_n = (n as f64).sqrt();
- if npatch == n && w == h {
- return Ok(xs.clone());
- }
- let class_pos_embed = self.pos_embed.i((.., ..1))?;
- let patch_pos_embed = self.pos_embed.i((.., 1..))?;
- let dim = xs.dim(D::Minus1)?;
- let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
- let patch_pos_embed = patch_pos_embed
- .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
- .transpose(2, 3)?
- .transpose(1, 2)?;
- // This uses bicubic interpolation in the original implementation.
- let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
- let el_count = patch_pos_embed.shape().elem_count();
- let patch_pos_embed =
- patch_pos_embed
- .transpose(1, 2)?
- .transpose(2, 3)?
- .reshape((1, el_count / dim, dim))?;
- Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
- }
-
- fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
- let (_b, _nc, w, h) = xs.dims4()?;
- let xs = self.patch_embed.forward(xs)?;
- let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
- &xs + &self.interpolate_pos_encoding(&xs, w, h)?
- }
-*/
-
#[derive(Parser)]
struct Args {
#[arg(long)]
- model: String,
+ model: Option<String>,
#[arg(long)]
image: String,
@@ -99,13 +108,24 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
+ let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
- let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+ let model = match args.model {
+ Some(model) => std::path::PathBuf::from(model),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("lmz/candle-sam".to_string());
+ api.get("sam_vit_b_01ec64.safetensors")?
+ }
+ };
+ let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
- let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+ let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+ let (mask, iou_predictions) = sam.forward(&image, false)?;
+ println!("mask: {mask:?}");
+ println!("iou_predictions: {iou_predictions:?}");
Ok(())
}
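As a quick sanity check, the normalization at the heart of the new LayerNorm2d can be exercised on its own. The sketch below is not part of the commit; it assumes a plain candle dependency on CPU and replays the same mean_keepdim / broadcast_* sequence from the forward pass above on a dummy tensor, without the learned weight and bias:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dummy activation map with the (B, C, H, W) layout LayerNorm2d expects.
    let xs = Tensor::randn(0f32, 1f32, (1, 3, 4, 4), &dev)?;
    let eps = 1e-6;
    // Normalize across the channel dimension (dim 1) at each spatial position.
    let u = xs.mean_keepdim(1)?;
    let xs = xs.broadcast_sub(&u)?;
    let s = xs.sqr()?.mean_keepdim(1)?;
    let ys = xs.broadcast_div(&(s + eps)?.sqrt()?)?;
    // The per-channel mean of `ys` should now be close to zero everywhere.
    println!("normalized: {ys:?}");
    Ok(())
}

With --model now optional, the example should also be runnable without a local checkpoint, presumably along the lines of `cargo run --example segment-anything --release -- --image some-image.jpg`, letting the hf-hub fallback fetch sam_vit_b_01ec64.safetensors from lmz/candle-sam.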