diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 12:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 12:26:56 +0100 |
commit | 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch) | |
tree | 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_mask_decoder.rs | |
parent | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff) | |
download | candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2 candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip |
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.
* Add the target point.
* Fix.
* Remove the allow-unused.
* Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index acbfeeea..598af1f6 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,4 +1,4 @@ -use candle::{DType, IndexOp, Result, Tensor, D}; +use candle::{IndexOp, Result, Tensor}; use candle_nn::{Linear, Module, VarBuilder}; use crate::model_transformer::TwoWayTransformer; @@ -188,7 +188,7 @@ impl MaskDecoder { // Expand per-image data in batch direction to be per mask let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?; - let src = (src + dense_prompt_embeddings)?; + let src = src.broadcast_add(dense_prompt_embeddings)?; let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?; let (b, c, h, w) = src.dims4()?; |