summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_sam.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 05:53:08 +0100
committerGitHub <noreply@github.com>2023-09-08 05:53:08 +0100
commit3898e500debd632b520054ebfa42f8333323a20e (patch)
treec13606f2430abab08e2bfc4c24ea89b2594e485b /candle-examples/examples/segment-anything/model_sam.rs
parent79c27fc489f2eece486fa433a0ae75c66a398e6f (diff)
downloadcandle-3898e500debd632b520054ebfa42f8333323a20e.tar.gz
candle-3898e500debd632b520054ebfa42f8333323a20e.tar.bz2
candle-3898e500debd632b520054ebfa42f8333323a20e.zip
Generate a mask image + the scaled input image. (#769)
* Also round-trip the original image. * Make it possible to use a safetensors input.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs10
1 files changed, 9 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index acba7ef4..237163a3 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -87,7 +87,15 @@ impl Sam {
Ok((low_res_mask, iou_predictions))
}
- fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+ pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let img = img
+ .broadcast_mul(&self.pixel_std)?
+ .broadcast_add(&self.pixel_mean)?;
+ img.maximum(&img.zeros_like()?)?
+ .minimum(&(img.ones_like()? * 255.)?)
+ }
+
+ pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?