summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index e4291ebb..b401a900 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
@@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
- let grid = Tensor::ones((h, w), DType::F32, device)?;
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
@@ -157,8 +156,9 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()