summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 20:13:29 +0100
committerGitHub <noreply@github.com>2023-09-08 20:13:29 +0100
commitacf8f10ae17d7f472dc1a634fbd7358a79d7b4d4 (patch)
treed36b6aa116a95559d9a2d2e40b79e89b6f5537cf /candle-examples/examples/segment-anything/model_prompt_encoder.rs
parent0906acab9186fbb14a2268e12dd66c13b0877f3e (diff)
downloadcandle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.gz
candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.bz2
candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.zip
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index 40cc6e36..7bbe8419 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -161,21 +161,21 @@ impl PromptEncoder {
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(0f32)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
- let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
+ let labels0 = labels.eq(0f32)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
- let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
+ let labels1 = labels.eq(1f32)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,