summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/segment-anything/main.rs27
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs10
2 files changed, 34 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 03ebe346..89d5b56c 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -108,8 +108,20 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image =
- candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
+ let image = if args.image.ends_with(".safetensors") {
+ let mut tensors = candle::safetensors::load(&args.image, &device)?;
+ match tensors.remove("image") {
+ Some(image) => image,
+ None => {
+ if tensors.len() != 1 {
+ anyhow::bail!("multiple tensors in '{}'", args.image)
+ }
+ tensors.into_values().next().unwrap()
+ }
+ }
+ } else {
+ candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
+ };
println!("loaded image {image:?}");
let model = match args.model {
@@ -128,5 +140,16 @@ pub fn main() -> anyhow::Result<()> {
let (mask, iou_predictions) = sam.forward(&image, false)?;
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
+
+ // Save the mask as an image.
+ let mask = mask.ge(&mask.zeros_like()?)?;
+ let mask = (mask * 255.)?.squeeze(0)?;
+ let (_one, h, w) = mask.dims3()?;
+ let mask = mask.expand((3, h, w))?;
+ candle_examples::save_image(&mask, "sam_mask.png")?;
+
+ let image = sam.preprocess(&image)?;
+ let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
+ candle_examples::save_image(&image, "sam_input_scaled.png")?;
Ok(())
}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index acba7ef4..237163a3 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -87,7 +87,15 @@ impl Sam {
Ok((low_res_mask, iou_predictions))
}
- fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+ pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let img = img
+ .broadcast_mul(&self.pixel_std)?
+ .broadcast_add(&self.pixel_mean)?;
+ img.maximum(&img.zeros_like()?)?
+ .minimum(&(img.ones_like()? * 255.)?)
+ }
+
+ pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?