summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 09:39:10 +0100
committerGitHub <noreply@github.com>2023-09-08 09:39:10 +0100
commitc1453f00b11c9dd12c5aa81fb4355ce47d22d477 (patch)
tree70fac003f636d26db3b1df5ef56980500700c0f1 /candle-examples/examples/segment-anything
parent989a4807b151f08c651b5027cc1b547a59adf966 (diff)
downloadcandle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.tar.gz
candle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.tar.bz2
candle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.zip
Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example. * Properly handle the labels when embedding the point prompts.
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r--candle-examples/examples/segment-anything/main.rs7
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs23
2 files changed, 28 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 89d5b56c..c53c1010 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> {
let image = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
- match tensors.remove("image") {
+ let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
@@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> {
}
tensors.into_values().next().unwrap()
}
+ };
+ if image.rank() == 4 {
+ image.get(0)?
+ } else {
+ image
}
} else {
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index aab0c4fd..e4291ebb 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -157,7 +157,28 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
- // TODO: tweak based on labels.
+ let zeros = point_embedding.zeros_like()?;
+ let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+ &self
+ .not_a_point_embed
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &point_embedding,
+ )?;
+ let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
+ &self.point_embeddings[0]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels0)?;
+ let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
+ &self.point_embeddings[1]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels1)?;
Ok(point_embedding)
}