diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 54 |
1 files changed, 38 insertions, 16 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 7ac4c66d..c6ffffd2 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{Linear, Module, VarBuilder}; #[derive(Debug)] struct PostionEmbeddingRandom { @@ -17,7 +17,7 @@ impl PostionEmbeddingRandom { fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> { let coords = coords.affine(2., -1.)?; - let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?; + let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?; let coords = (coords * (2. * std::f64::consts::PI))?; Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) } @@ -25,12 +25,14 @@ impl PostionEmbeddingRandom { fn forward(&self, h: usize, w: usize) -> Result<Tensor> { let device = self.positional_encoding_gaussian_matrix.device(); let grid = Tensor::ones((h, w), DType::F32, device)?; - // TODO: cumsum - let x_embed = (&grid - 0.5)?; - // TODO: cumsum - let y_embed = (&grid - 0.5)?; - let x_embed = (x_embed / w as f64)?; - let y_embed = (y_embed / h as f64)?; + let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let x_embed = (x_embed / w as f64)? + .reshape((1, w))? + .broadcast_as((h, w))?; + let y_embed = (y_embed / h as f64)? + .reshape((h, 1))? + .broadcast_as((h, w))?; let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; self.pe_encoding(&coords)?.permute((2, 0, 1)) } @@ -55,13 +57,14 @@ pub struct PromptEncoder { point_embeddings: Vec<candle_nn::Embedding>, not_a_point_embed: candle_nn::Embedding, mask_downscaling_conv1: candle_nn::Conv2d, - mask_downscaling_ln1: LayerNorm, + mask_downscaling_ln1: crate::LayerNorm2d, mask_downscaling_conv2: candle_nn::Conv2d, - mask_downscaling_ln2: LayerNorm, + mask_downscaling_ln2: crate::LayerNorm2d, mask_downscaling_conv3: candle_nn::Conv2d, no_mask_embed: candle_nn::Embedding, image_embedding_size: (usize, usize), input_image_size: (usize, usize), + embed_dim: usize, } impl PromptEncoder { @@ -97,8 +100,9 @@ impl PromptEncoder { vb.pp("mask_downscaling.6"), )?; let mask_downscaling_ln1 = - layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; - let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; + crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; + let mask_downscaling_ln2 = + crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; let mut point_embeddings = Vec::with_capacity(num_points_embeddings); let vb_e = vb.pp("point_embeddings"); for i in 0..num_points_embeddings { @@ -117,9 +121,16 @@ impl PromptEncoder { no_mask_embed, image_embedding_size, input_image_size, + embed_dim, }) } + pub fn get_dense_pe(&self) -> Result<Tensor> { + self.pe_layer + .forward(self.image_embedding_size.0, self.image_embedding_size.1)? + .unsqueeze(0) + } + fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> { masks .apply(&self.mask_downscaling_conv1)? @@ -133,7 +144,16 @@ impl PromptEncoder { fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> { let points = (points + 0.5)?; - let points = if pad { todo!() } else { points }; + let dev = points.device(); + let (points, labels) = if pad { + let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?; + let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?; + let points = Tensor::cat(&[&points, &padding_point], 1)?; + let labels = Tensor::cat(&[labels, &padding_label], 1)?; + (points, labels) + } else { + (points, labels.clone()) + }; let point_embedding = self .pe_layer .forward_with_coords(&points, self.input_image_size)?; @@ -154,7 +174,7 @@ impl PromptEncoder { Tensor::cat(&[&ce1, &ce2], 1) } - fn forward( + pub fn forward( &self, points: Option<(&Tensor, &Tensor)>, boxes: Option<&Tensor>, @@ -172,7 +192,9 @@ impl PromptEncoder { (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?, (Some(se_points), None) => se_points, (None, Some(se_boxes)) => se_boxes, - (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?, + (None, None) => { + Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? + } }; let dense_embeddings = match masks { @@ -180,7 +202,7 @@ impl PromptEncoder { let emb = self.no_mask_embed.embeddings(); emb.reshape((1, emb.elem_count(), 1, 1))?.expand(( 1, - 0, + emb.elem_count(), self.image_embedding_size.0, self.image_embedding_size.1, ))? |