summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs109
1 files changed, 6 insertions, 103 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 9ce2f158..21ba0415 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-pub mod model_image_encoder;
-pub mod model_mask_decoder;
-pub mod model_prompt_encoder;
-pub mod model_sam;
-pub mod model_tiny_vit;
-pub mod model_transformer;
-
-use candle::{DType, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
+use candle::DType;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segment_anything::sam;
use clap::Parser;
-pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- let inner = if bias {
- candle_nn::linear(in_dim, out_dim, vb)?
- } else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)?
- };
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Ok(Linear { inner, span })
-}
-
-#[derive(Debug)]
-pub struct LayerNorm2d {
- weight: Tensor,
- bias: Tensor,
- num_channels: usize,
- eps: f64,
-}
-
-impl LayerNorm2d {
- pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let weight = vb.get(num_channels, "weight")?;
- let bias = vb.get(num_channels, "bias")?;
- Ok(Self {
- weight,
- bias,
- num_channels,
- eps,
- })
- }
-}
-
-impl Module for LayerNorm2d {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let u = xs.mean_keepdim(1)?;
- let xs = xs.broadcast_sub(&u)?;
- let s = xs.sqr()?.mean_keepdim(1)?;
- let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
- xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
- .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
- }
-}
-
-#[derive(Debug)]
-pub struct MlpBlock {
- lin1: Linear,
- lin2: Linear,
- activation: candle_nn::Activation,
- span: tracing::Span,
-}
-
-impl MlpBlock {
- pub fn new(
- embedding_dim: usize,
- mlp_dim: usize,
- activation: candle_nn::Activation,
- vb: VarBuilder,
- ) -> Result<Self> {
- let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
- let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
- let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
- Ok(Self {
- lin1,
- lin2,
- activation,
- span,
- })
- }
-}
-
-impl Module for MlpBlock {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- xs.apply(&self.lin1)?
- .apply(&self.activation)?
- .apply(&self.lin2)
- }
-}
-
-#[derive(Debug)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
- let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
+ let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
- model_sam::Sam::new_tiny(vb)? // tiny vit_t
+ sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
- model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {