summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 12:26:56 +0100
committerGitHub <noreply@github.com>2023-09-08 12:26:56 +0100
commit28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch)
tree11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_mask_decoder.rs
parentc1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff)
downloadcandle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index acbfeeea..598af1f6 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{IndexOp, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@@ -188,7 +188,7 @@ impl MaskDecoder {
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
- let src = (src + dense_prompt_embeddings)?;
+ let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;