summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 12:26:56 +0100
committerGitHub <noreply@github.com>2023-09-08 12:26:56 +0100
commit28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch)
tree11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_prompt_encoder.rs
parentc1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff)
downloadcandle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index e4291ebb..b401a900 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
@@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
- let grid = Tensor::ones((h, w), DType::F32, device)?;
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
@@ -157,8 +156,9 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()