Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
 candle-examples/examples/segment-anything/model_prompt_encoder.rs | 54 ++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 38 insertions(+), 16 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index 7ac4c66d..c6ffffd2 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
#[derive(Debug)]
struct PostionEmbeddingRandom {
@@ -17,7 +17,7 @@ impl PostionEmbeddingRandom {
fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
let coords = coords.affine(2., -1.)?;
- let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
let coords = (coords * (2. * std::f64::consts::PI))?;
Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
}
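Why `broadcast_matmul`: point coordinates arrive with a leading batch dimension while the Gaussian feature matrix is rank 2, so a plain `matmul` rank check fails. A minimal sketch of what this hunk computes, with assumed shapes (coords: (B, N, 2) in [0, 1], gaussian: (2, F)):

    use candle::{Result, Tensor, D};

    fn pe_sketch(coords: &Tensor, gaussian: &Tensor) -> Result<Tensor> {
        let coords = coords.affine(2., -1.)?; // map [0, 1] -> [-1, 1]
        // A batched (B, N, 2) lhs against a plain (2, F) matrix is why
        // broadcast_matmul replaces matmul here.
        let coords = coords.broadcast_matmul(gaussian)?;
        let coords = (coords * (2. * std::f64::consts::PI))?;
        Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) // (B, N, 2F)
    }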
@@ -25,12 +25,14 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
let grid = Tensor::ones((h, w), DType::F32, device)?;
- // TODO: cumsum
- let x_embed = (&grid - 0.5)?;
- // TODO: cumsum
- let y_embed = (&grid - 0.5)?;
- let x_embed = (x_embed / w as f64)?;
- let y_embed = (y_embed / h as f64)?;
+ let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let x_embed = (x_embed / w as f64)?
+ .reshape((1, w))?
+ .broadcast_as((h, w))?;
+ let y_embed = (y_embed / h as f64)?
+ .reshape((h, 1))?
+ .broadcast_as((h, w))?;
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
self.pe_encoding(&coords)?.permute((2, 0, 1))
}
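The two `TODO: cumsum` placeholders are gone: on a grid of ones, a cumulative sum along an axis is simply 1, 2, ..., n, so the reference's `cumsum - 0.5` normalized by the axis length equals `(arange(0, n) + 0.5) / n`, which the new code builds directly and broadcasts to (h, w). A quick standalone check of one row, assuming w = 4:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let w = 4usize;
        // cumsum over a row of ones, minus 0.5, over w: 0.5/4, 1.5/4, 2.5/4, 3.5/4
        let x = (Tensor::arange(0u32, w as u32, &dev)?.to_dtype(DType::F32)? + 0.5)?;
        let x = (x / w as f64)?;
        assert_eq!(x.to_vec1::<f32>()?, [0.125, 0.375, 0.625, 0.875]);
        Ok(())
    }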
@@ -55,13 +57,14 @@ pub struct PromptEncoder {
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
- mask_downscaling_ln1: LayerNorm,
+ mask_downscaling_ln1: crate::LayerNorm2d,
mask_downscaling_conv2: candle_nn::Conv2d,
- mask_downscaling_ln2: LayerNorm,
+ mask_downscaling_ln2: crate::LayerNorm2d,
mask_downscaling_conv3: candle_nn::Conv2d,
no_mask_embed: candle_nn::Embedding,
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
+ embed_dim: usize,
}
impl PromptEncoder {
@@ -97,8 +100,9 @@ impl PromptEncoder {
vb.pp("mask_downscaling.6"),
)?;
let mask_downscaling_ln1 =
- layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
- let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 =
+ crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
let vb_e = vb.pp("point_embeddings");
for i in 0..num_points_embeddings {
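`candle_nn::layer_norm` normalizes over the last dimension, which is wrong for the NCHW activations flowing through the mask-downscaling convolutions; the patch swaps in a crate-local `LayerNorm2d` that normalizes over the channel axis instead. A hedged sketch of what such a module plausibly looks like; the actual `crate::LayerNorm2d` is defined elsewhere in the example and may differ:

    use candle::{Result, Tensor};
    use candle_nn::VarBuilder;

    #[derive(Debug)]
    struct LayerNorm2d {
        weight: Tensor, // (C,)
        bias: Tensor,   // (C,)
        eps: f64,
    }

    impl LayerNorm2d {
        fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
            let weight = vb.get(num_channels, "weight")?;
            let bias = vb.get(num_channels, "bias")?;
            Ok(Self { weight, bias, eps })
        }

        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // Normalize over the channel dimension (dim 1) of (N, C, H, W).
            let mean = xs.mean_keepdim(1)?;
            let xs = xs.broadcast_sub(&mean)?;
            let var = xs.sqr()?.mean_keepdim(1)?;
            let xs = xs.broadcast_div(&(var + self.eps)?.sqrt()?)?;
            let c = self.weight.dim(0)?;
            xs.broadcast_mul(&self.weight.reshape((1, c, 1, 1))?)?
                .broadcast_add(&self.bias.reshape((1, c, 1, 1))?)
        }
    }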
@@ -117,9 +121,16 @@ impl PromptEncoder {
no_mask_embed,
image_embedding_size,
input_image_size,
+ embed_dim,
})
}
+ pub fn get_dense_pe(&self) -> Result<Tensor> {
+ self.pe_layer
+ .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+ .unsqueeze(0)
+ }
+
fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
masks
.apply(&self.mask_downscaling_conv1)?
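The new `get_dense_pe` is public so the caller (the mask decoder) can fetch the positional grid matching the (1, C, H, W) image embedding; `forward` on the pe_layer yields (C, H, W) and the `unsqueeze(0)` adds the batch dimension. How the decoder consumes it is not part of this diff; the helper below is a hypothetical illustration:

    use candle::{Result, Tensor};

    // Hypothetical consumer; `PromptEncoder` is the struct from this file.
    fn keys_with_position(image_embedding: &Tensor, encoder: &PromptEncoder) -> Result<Tensor> {
        // Both sides are (1, C, H, W), e.g. (1, 256, 64, 64) for SAM-sized models.
        image_embedding.broadcast_add(&encoder.get_dense_pe()?)
    }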
@@ -133,7 +144,16 @@ impl PromptEncoder {
fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
let points = (points + 0.5)?;
- let points = if pad { todo!() } else { points };
+ let dev = points.device();
+ let (points, labels) = if pad {
+ let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+ let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
+ let points = Tensor::cat(&[&points, &padding_point], 1)?;
+ let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+ (points, labels)
+ } else {
+ (points, labels.clone())
+ };
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
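The former `todo!()` is replaced by real padding: when `pad` is set, one dummy point at the origin with label -1 is appended per batch element, keeping point prompts shape-compatible with the box path; label -1 presumably routes to `not_a_point_embed` further down, so the pad contributes no real location. Shape-wise, assuming one batch with two clicked points:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Two clicked points in one batch: (B, N, 2) coords, (B, N) labels.
        let points = Tensor::zeros((1, 2, 2), DType::F32, &dev)?;
        let labels = Tensor::ones((1, 2), DType::F32, &dev)?;
        // The pad branch above appends one dummy point with label -1.
        let pad_pt = Tensor::zeros((1, 1, 2), DType::F32, &dev)?;
        let pad_lb = (Tensor::ones((1, 1), DType::F32, &dev)? * (-1f64))?;
        let points = Tensor::cat(&[&points, &pad_pt], 1)?;
        let labels = Tensor::cat(&[&labels, &pad_lb], 1)?;
        assert_eq!(points.dims(), &[1, 3, 2]);
        assert_eq!(labels.dims(), &[1, 3]);
        Ok(())
    }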
@@ -154,7 +174,7 @@ impl PromptEncoder {
Tensor::cat(&[&ce1, &ce2], 1)
}
- fn forward(
+ pub fn forward(
&self,
points: Option<(&Tensor, &Tensor)>,
boxes: Option<&Tensor>,
@@ -172,7 +192,9 @@ impl PromptEncoder {
(Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
(Some(se_points), None) => se_points,
(None, Some(se_boxes)) => se_boxes,
- (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+ (None, None) => {
+ Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+ }
};
let dense_embeddings = match masks {
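The empty-prompt case used to return a rank-1 zeros tensor; it now returns (1, 0, embed_dim), using the newly stored `embed_dim` field, so whatever concatenates sparse tokens downstream sees a consistent rank and width even with no prompts. A minimal check that a zero-length token axis concatenates cleanly, assuming candle accepts zero-sized dimensions (as the patched code relies on) and an embed_dim of 4:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let empty = Tensor::zeros((1, 0, 4), DType::F32, &dev)?; // no prompts
        let tokens = Tensor::zeros((1, 3, 4), DType::F32, &dev)?; // e.g. decoder tokens
        let cat = Tensor::cat(&[&tokens, &empty], 1)?;
        assert_eq!(cat.dims(), &[1, 3, 4]);
        Ok(())
    }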
@@ -180,7 +202,7 @@ impl PromptEncoder {
let emb = self.no_mask_embed.embeddings();
emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
1,
- 0,
+ emb.elem_count(),
self.image_embedding_size.0,
self.image_embedding_size.1,
))?
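The `expand` previously had a hard-coded 0 in the channel slot, which would have produced an empty dense embedding; it now uses `emb.elem_count()`, so the learned no-mask vector is broadcast over every spatial position of the embedding grid. Illustrated with assumed sizes:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (c, h, w) = (256, 64, 64);
        // Stand-in for no_mask_embed.embeddings(): a learned (C,) vector.
        let emb = Tensor::zeros(c, DType::F32, &dev)?;
        // (1, C, 1, 1) expanded over the H x W grid, as in the fixed hunk.
        let dense = emb.reshape((1, c, 1, 1))?.expand((1, c, h, w))?;
        assert_eq!(dense.dims(), &[1, 256, 64, 64]);
        Ok(())
    }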