diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 40cc6e36..7bbe8419 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -161,21 +161,21 @@ impl PromptEncoder { .forward_with_coords(&points, self.input_image_size)?; let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; let zeros = point_embedding.zeros_like()?; - let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond( + let point_embedding = labels.lt(0f32)?.where_cond( &self .not_a_point_embed .embeddings() .broadcast_as(zeros.shape())?, &point_embedding, )?; - let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond( + let labels0 = labels.eq(0f32)?.where_cond( &self.point_embeddings[0] .embeddings() .broadcast_as(zeros.shape())?, &zeros, )?; let point_embedding = (point_embedding + labels0)?; - let labels1 = labels.eq(&labels.ones_like()?)?.where_cond( + let labels1 = labels.eq(1f32)?.where_cond( &self.point_embeddings[1] .embeddings() .broadcast_as(zeros.shape())?, |