diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 109 |
1 files changed, 6 insertions, 103 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 9ce2f158..21ba0415 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -7,108 +7,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -pub mod model_image_encoder; -pub mod model_mask_decoder; -pub mod model_prompt_encoder; -pub mod model_sam; -pub mod model_tiny_vit; -pub mod model_transformer; - -use candle::{DType, Result, Tensor}; -use candle_nn::{Module, VarBuilder}; +use candle::DType; +use candle_nn::VarBuilder; +use candle_transformers::models::segment_anything::sam; use clap::Parser; -pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - let inner = if bias { - candle_nn::linear(in_dim, out_dim, vb)? - } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb)? - }; - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Ok(Linear { inner, span }) -} - -#[derive(Debug)] -pub struct LayerNorm2d { - weight: Tensor, - bias: Tensor, - num_channels: usize, - eps: f64, -} - -impl LayerNorm2d { - pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let weight = vb.get(num_channels, "weight")?; - let bias = vb.get(num_channels, "bias")?; - Ok(Self { - weight, - bias, - num_channels, - eps, - }) - } -} - -impl Module for LayerNorm2d { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let u = xs.mean_keepdim(1)?; - let xs = xs.broadcast_sub(&u)?; - let s = xs.sqr()?.mean_keepdim(1)?; - let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; - xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? - .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) - } -} - -#[derive(Debug)] -pub struct MlpBlock { - lin1: Linear, - lin2: Linear, - activation: candle_nn::Activation, - span: tracing::Span, -} - -impl MlpBlock { - pub fn new( - embedding_dim: usize, - mlp_dim: usize, - activation: candle_nn::Activation, - vb: VarBuilder, - ) -> Result<Self> { - let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; - let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; - let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); - Ok(Self { - lin1, - lin2, - activation, - span, - }) - } -} - -impl Module for MlpBlock { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - xs.apply(&self.lin1)? - .apply(&self.activation)? - .apply(&self.lin2) - } -} - -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Parser)] struct Args { #[arg(long)] @@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> { let (_c, h, w) = image.dims3()?; (image, h, w) } else { - let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?; + let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; (image.to_device(&device)?, h, w) }; println!("loaded image {image:?}"); @@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> { let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let sam = if args.use_tiny { - model_sam::Sam::new_tiny(vb)? // tiny vit_t + sam::Sam::new_tiny(vb)? // tiny vit_t } else { - model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b }; if args.generate_masks { |