summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r--candle-examples/examples/segment-anything/main.rs7
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs6
2 files changed, 9 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 4627248c..ce8e3bb4 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -209,12 +209,17 @@ pub fn main() -> anyhow::Result<()> {
}
} else {
let point = Some((args.point_x, args.point_y));
+ let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+ println!(
+ "mask generated in {:.2}s",
+ start_time.elapsed().as_secs_f32()
+ );
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
// Save the mask as an image.
- let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
+ let mask = (mask.ge(0f32)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index 40cc6e36..7bbe8419 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -161,21 +161,21 @@ impl PromptEncoder {
.forward_with_coords(&points, self.input_image_size)?;
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(0f32)?.where_cond(
&self
.not_a_point_embed
.embeddings()
.broadcast_as(zeros.shape())?,
&point_embedding,
)?;
- let labels0 = labels.eq(&labels.zeros_like()?)?.where_cond(
+ let labels0 = labels.eq(0f32)?.where_cond(
&self.point_embeddings[0]
.embeddings()
.broadcast_as(zeros.shape())?,
&zeros,
)?;
let point_embedding = (point_embedding + labels0)?;
- let labels1 = labels.eq(&labels.ones_like()?)?.where_cond(
+ let labels1 = labels.eq(1f32)?.where_cond(
&self.point_embeddings[1]
.embeddings()
.broadcast_as(zeros.shape())?,