diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 05:53:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 05:53:08 +0100 |
commit | 3898e500debd632b520054ebfa42f8333323a20e (patch) | |
tree | c13606f2430abab08e2bfc4c24ea89b2594e485b /candle-examples/examples/segment-anything/model_sam.rs | |
parent | 79c27fc489f2eece486fa433a0ae75c66a398e6f (diff) | |
download | candle-3898e500debd632b520054ebfa42f8333323a20e.tar.gz candle-3898e500debd632b520054ebfa42f8333323a20e.tar.bz2 candle-3898e500debd632b520054ebfa42f8333323a20e.zip |
Generate a mask image + the scaled input image. (#769)
* Also round-trip the original image.
* Make it possible to use a safetensors input.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index acba7ef4..237163a3 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -87,7 +87,15 @@ impl Sam { Ok((low_res_mask, iou_predictions)) } - fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { + let img = img + .broadcast_mul(&self.pixel_std)? + .broadcast_add(&self.pixel_mean)?; + img.maximum(&img.zeros_like()?)? + .minimum(&(img.ones_like()? * 255.)?) + } + + pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { let (c, h, w) = img.dims3()?; let img = img .to_dtype(DType::F32)? |