//! Segment Anything Model (SAM)
//!
//! SAM is an architecture for image segmentation, capable of segmenting any object
//! in an image based on prompts like points or boxes.
//! This model provides a robust and fast image segmentation pipeline that can be tweaked via
//! prompting: requesting that some points be in the target mask, requesting that some
//! points be part of the background (so _not_ in the target mask), or specifying a
//! bounding box.
//!
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/candle-segment-anything-wasm)
//! - 💻 [GH Link](https://github.com/facebookresearch/segment-anything)
//! - 📝 [Paper](https://arxiv.org/abs/2304.02643)
//! - 💡 The default backbone can be replaced by the smaller and faster TinyViT model based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
//!
//!
//! ## Example
//!
//! ```bash
//! cargo run --example segment-anything --release -- \
//!     --image candle-examples/examples/yolo-v8/assets/bike.jpg \
//!     --use-tiny --point 0.6,0.6 --point 0.6,0.55
//! ```
//!
//!
//! //! //! //!
//! //! //! > Original; Prompt with `--point 0.6,0.55`; Prompt with `--point 0.6,0.6 --point 0.6,0.55` //! pub use crate::models::with_tracing::Linear; use candle::{Result, Tensor}; use candle_nn::{Module, VarBuilder}; pub mod image_encoder; pub mod mask_decoder; pub mod prompt_encoder; pub mod sam; pub mod tiny_vit; pub mod transformer; pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { if bias { crate::models::with_tracing::linear(in_dim, out_dim, vb) } else { crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb) } } #[derive(Debug)] pub struct LayerNorm2d { weight: Tensor, bias: Tensor, num_channels: usize, eps: f64, } impl LayerNorm2d { pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(num_channels, "weight")?; let bias = vb.get(num_channels, "bias")?; Ok(Self { weight, bias, num_channels, eps, }) } } impl Module for LayerNorm2d { fn forward(&self, xs: &Tensor) -> Result { let u = xs.mean_keepdim(1)?; let xs = xs.broadcast_sub(&u)?; let s = xs.sqr()?.mean_keepdim(1)?; let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) } } #[derive(Debug)] pub struct MlpBlock { lin1: Linear, lin2: Linear, activation: candle_nn::Activation, span: tracing::Span, } impl MlpBlock { pub fn new( embedding_dim: usize, mlp_dim: usize, activation: candle_nn::Activation, vb: VarBuilder, ) -> Result { let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); Ok(Self { lin1, lin2, activation, span, }) } } impl Module for MlpBlock { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); xs.apply(&self.lin1)? .apply(&self.activation)? .apply(&self.lin2) } }