summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs192
1 files changed, 192 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
new file mode 100644
index 00000000..7ac4c66d
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -0,0 +1,192 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PostionEmbeddingRandom {
+ positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PostionEmbeddingRandom {
+ fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+ let positional_encoding_gaussian_matrix =
+ vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+ Ok(Self {
+ positional_encoding_gaussian_matrix,
+ })
+ }
+
+ fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+ let coords = coords.affine(2., -1.)?;
+ let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = (coords * (2. * std::f64::consts::PI))?;
+ Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+ }
+
+ fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+ let device = self.positional_encoding_gaussian_matrix.device();
+ let grid = Tensor::ones((h, w), DType::F32, device)?;
+ // TODO: cumsum
+ let x_embed = (&grid - 0.5)?;
+ // TODO: cumsum
+ let y_embed = (&grid - 0.5)?;
+ let x_embed = (x_embed / w as f64)?;
+ let y_embed = (y_embed / h as f64)?;
+ let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+ self.pe_encoding(&coords)?.permute((2, 0, 1))
+ }
+
+ fn forward_with_coords(
+ &self,
+ coords_input: &Tensor,
+ image_size: (usize, usize),
+ ) -> Result<Tensor> {
+ let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+ let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+ let c = coords_input.dim(D::Minus1)?;
+ let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+ let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+ self.pe_encoding(&coords)
+ }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+ pe_layer: PostionEmbeddingRandom,
+ point_embeddings: Vec<candle_nn::Embedding>,
+ not_a_point_embed: candle_nn::Embedding,
+ mask_downscaling_conv1: candle_nn::Conv2d,
+ mask_downscaling_ln1: LayerNorm,
+ mask_downscaling_conv2: candle_nn::Conv2d,
+ mask_downscaling_ln2: LayerNorm,
+ mask_downscaling_conv3: candle_nn::Conv2d,
+ no_mask_embed: candle_nn::Embedding,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+}
+
+impl PromptEncoder {
+ pub fn new(
+ embed_dim: usize,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ mask_in_chans: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_points_embeddings = 4;
+ let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+ let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let mask_downscaling_conv1 =
+ candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+ let mask_downscaling_conv2 = candle_nn::conv2d(
+ mask_in_chans / 4,
+ mask_in_chans,
+ 2,
+ cfg,
+ vb.pp("mask_downscaling.3"),
+ )?;
+ let mask_downscaling_conv3 = candle_nn::conv2d(
+ mask_in_chans,
+ embed_dim,
+ 1,
+ Default::default(),
+ vb.pp("mask_downscaling.6"),
+ )?;
+ let mask_downscaling_ln1 =
+ layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+ let vb_e = vb.pp("point_embeddings");
+ for i in 0..num_points_embeddings {
+ let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+ point_embeddings.push(emb)
+ }
+ Ok(Self {
+ pe_layer,
+ point_embeddings,
+ not_a_point_embed,
+ mask_downscaling_conv1,
+ mask_downscaling_ln1,
+ mask_downscaling_conv2,
+ mask_downscaling_ln2,
+ mask_downscaling_conv3,
+ no_mask_embed,
+ image_embedding_size,
+ input_image_size,
+ })
+ }
+
+ fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+ masks
+ .apply(&self.mask_downscaling_conv1)?
+ .apply(&self.mask_downscaling_ln1)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv2)?
+ .apply(&self.mask_downscaling_ln2)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv3)
+ }
+
+ fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+ let points = (points + 0.5)?;
+ let points = if pad { todo!() } else { points };
+ let point_embedding = self
+ .pe_layer
+ .forward_with_coords(&points, self.input_image_size)?;
+ // TODO: tweak based on labels.
+ Ok(point_embedding)
+ }
+
+ fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+ let boxes = (boxes + 0.5)?;
+ let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+ let corner_embedding = self
+ .pe_layer
+ .forward_with_coords(&coords, self.input_image_size)?;
+ let ce1 = corner_embedding.i((.., 0))?;
+ let ce2 = corner_embedding.i((.., 1))?;
+ let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+ let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+ Tensor::cat(&[&ce1, &ce2], 1)
+ }
+
+ fn forward(
+ &self,
+ points: Option<(&Tensor, &Tensor)>,
+ boxes: Option<&Tensor>,
+ masks: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
+ let se_points = match points {
+ Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+ None => None,
+ };
+ let se_boxes = match boxes {
+ Some(boxes) => Some(self.embed_boxes(boxes)?),
+ None => None,
+ };
+ let sparse_embeddings = match (se_points, se_boxes) {
+ (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+ (Some(se_points), None) => se_points,
+ (None, Some(se_boxes)) => se_boxes,
+ (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+ };
+
+ let dense_embeddings = match masks {
+ None => {
+ let emb = self.no_mask_embed.embeddings();
+ emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+ 1,
+ 0,
+ self.image_embedding_size.0,
+ self.image_embedding_size.1,
+ ))?
+ }
+ Some(masks) => self.embed_masks(masks)?,
+ };
+ Ok((sparse_embeddings, dense_embeddings))
+ }
+}