diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 20:13:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 20:13:29 +0100 |
commit | acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4 (patch) | |
tree | d36b6aa116a95559d9a2d2e40b79e89b6f5537cf /candle-examples/examples/segment-anything/model_prompt_encoder.rs | |
parent | 0906acab9186fbb14a2268e12dd66c13b0877f3e (diff) | |
download | candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.gz candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.tar.bz2 candle-acf8f10ae17d7f472dc1a634fbd7358a79d7b4d4.zip |
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values.
* Add some time measurement.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 40cc6e36..7bbe8419 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -161,21 +161,21 @@ impl PromptEncoder { .forward_with_coords(&points, self.input_image_size)?; let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; let zeros = point_embedding.zeros_like()?; - let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond( + let point_embedding = labels.lt(0f32)?.where_cond( &self .not_a_point_embed .embeddings() .broadcast_as(zeros.shape())?, &point_embedding, )?; - let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond( + let labels0 = labels.eq(0f32)?.where_cond( &self.point_embeddings[0] .embeddings() .broadcast_as(zeros.shape())?, &zeros, )?; let point_embedding = (point_embedding + labels0)?; - let labels1 = labels.eq(&labels.ones_like()?)?.where_cond( + let labels1 = labels.eq(1f32)?.where_cond( &self.point_embeddings[1] .embeddings() .broadcast_as(zeros.shape())?, |