summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs23
1 files changed, 22 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index aab0c4fd..e4291ebb 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -157,7 +157,28 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
- // TODO: tweak based on labels.
+ let zeros = point_embedding.zeros_like()?;
+ let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+ &self
+ .not_a_point_embed
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &point_embedding,
+ )?;
+ let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
+ &self.point_embeddings[0]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels0)?;
+ let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
+ &self.point_embeddings[1]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels1)?;
Ok(point_embedding)
}