summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/segment_anything/sam.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-29 22:39:43 +0200
committerGitHub <noreply@github.com>2023-09-29 22:39:43 +0200
commitd188d6a7642c470732f740e93c035fd792c9706c (patch)
tree3433cd3903533f57e3432db3aa2001e856fd1aa5 /candle-transformers/src/models/segment_anything/sam.rs
parent0ac2db577b69a387eee577ee4bb7362a3b3ae691 (diff)
downloadcandle-d188d6a7642c470732f740e93c035fd792c9706c.tar.gz
candle-d188d6a7642c470732f740e93c035fd792c9706c.tar.bz2
candle-d188d6a7642c470732f740e93c035fd792c9706c.zip
Fix the multiple points case for sam. (#998)
Diffstat (limited to 'candle-transformers/src/models/segment_anything/sam.rs')
-rw-r--r--candle-transformers/src/models/segment_anything/sam.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index 49e95adb..6de7beb2 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -171,8 +171,8 @@ impl Sam {
[x, y]
})
.collect::<Vec<_>>();
- let points = Tensor::from_vec(xys, (n_points, 1, 2), img_embeddings.device())?;
- let labels = Tensor::ones((n_points, 1), DType::F32, img_embeddings.device())?;
+ let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
+ let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?;
Some((points, labels))
};
let points = points.as_ref().map(|(x, y)| (x, y));