diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 05:53:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 05:53:08 +0100 |
commit | 3898e500debd632b520054ebfa42f8333323a20e (patch) | |
tree | c13606f2430abab08e2bfc4c24ea89b2594e485b /candle-examples/examples/segment-anything/main.rs | |
parent | 79c27fc489f2eece486fa433a0ae75c66a398e6f (diff) | |
download | candle-3898e500debd632b520054ebfa42f8333323a20e.tar.gz candle-3898e500debd632b520054ebfa42f8333323a20e.tar.bz2 candle-3898e500debd632b520054ebfa42f8333323a20e.zip |
Generate a mask image + the scaled input image. (#769)
* Also round-trip the original image.
* Make it possible to use a safetensors input.
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 03ebe346..89d5b56c 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -108,8 +108,20 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = - candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?; + let image = if args.image.ends_with(".safetensors") { + let mut tensors = candle::safetensors::load(&args.image, &device)?; + match tensors.remove("image") { + Some(image) => image, + None => { + if tensors.len() != 1 { + anyhow::bail!("multiple tensors in '{}'", args.image) + } + tensors.into_values().next().unwrap() + } + } + } else { + candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)? + }; println!("loaded image {image:?}"); let model = match args.model { @@ -128,5 +140,16 @@ pub fn main() -> anyhow::Result<()> { let (mask, iou_predictions) = sam.forward(&image, false)?; println!("mask:\n{mask}"); println!("iou_predictions: {iou_predictions:?}"); + + // Save the mask as an image. + let mask = mask.ge(&mask.zeros_like()?)?; + let mask = (mask * 255.)?.squeeze(0)?; + let (_one, h, w) = mask.dims3()?; + let mask = mask.expand((3, h, w))?; + candle_examples::save_image(&mask, "sam_mask.png")?; + + let image = sam.preprocess(&image)?; + let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?; + candle_examples::save_image(&image, "sam_input_scaled.png")?; Ok(()) } |