summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r-- candle-examples/examples/segment-anything/main.rs 5
-rw-r--r-- candle-examples/examples/segment-anything/model_sam.rs 3
2 files changed, 5 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a2722270..03ebe346 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -108,7 +108,8 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+ let image =
+ candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@@ -125,7 +126,7 @@ pub fn main() -> anyhow::Result<()> {
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let (mask, iou_predictions) = sam.forward(&image, false)?;
- println!("mask: {mask:?}");
+ println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
Ok(())
}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 1c8e9a59..acba7ef4 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
-const IMAGE_SIZE: usize = 1024;
+pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
#[derive(Debug)]
@@ -90,6 +90,7 @@ impl Sam {
fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let img = img
+ .to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE {