summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_sam.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-07 21:45:16 +0100
committerGitHub <noreply@github.com>2023-09-07 21:45:16 +0100
commit79c27fc489f2eece486fa433a0ae75c66a398e6f (patch)
treec86b89b2c1270cf3207c5242de27c6b069652044 /candle-examples/examples/segment-anything/model_sam.rs
parent7396b8ed1a5394c58fcc772e5f6e6038577505b8 (diff)
downloadcandle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.gz
candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.bz2
candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.zip
Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice. * More fixes for the image aspect ratio.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 1c8e9a59..acba7ef4 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
-const IMAGE_SIZE: usize = 1024;
+pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
#[derive(Debug)]
@@ -90,6 +90,7 @@ impl Sam {
fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let img = img
+ .to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE {