diff options
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 7 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 6 |
2 files changed, 9 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 4627248c..ce8e3bb4 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> { } } else { let point = Some((args.point_x, args.point_y)); + let start_time = std::time::Instant::now(); let (mask, iou_predictions) = sam.forward(&image, point, false)?; + println!( + "mask generated in {:.2}s", + start_time.elapsed().as_secs_f32() + ); println!("mask:\n{mask}"); println!("iou_predictions: {iou_predictions:?}"); // Save the mask as an image. - let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?; + let mask = (mask.ge(0f32)? * 255.)?; let (_one, h, w) = mask.dims3()?; let mask = mask.expand((3, h, w))?; candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 40cc6e36..7bbe8419 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -161,21 +161,21 @@ impl PromptEncoder { .forward_with_coords(&points, self.input_image_size)?; let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; let zeros = point_embedding.zeros_like()?; - let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond( + let point_embedding = labels.lt(0f32)?.where_cond( &self .not_a_point_embed .embeddings() .broadcast_as(zeros.shape())?, &point_embedding, )?; - let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond( + let labels0 = labels.eq(0f32)?.where_cond( &self.point_embeddings[0] .embeddings() .broadcast_as(zeros.shape())?, &zeros, )?; let point_embedding = (point_embedding + labels0)?; - let labels1 = labels.eq(&labels.ones_like()?)?.where_cond( + let labels1 = labels.eq(1f32)?.where_cond( &self.point_embeddings[1] .embeddings() .broadcast_as(zeros.shape())?, |