Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r--  candle-examples/examples/segment-anything/model_image_encoder.rs  257
1 file changed, 257 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
new file mode 100644
index 00000000..c8b6fd7b
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -0,0 +1,257 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
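+/// Patch embedding: splits the input image into patches with a strided convolution and
+/// returns the patch embeddings in channels-last (B, H, W, C) layout.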
+#[derive(Debug)]
+struct PatchEmbed {
+    proj: candle_nn::Conv2d,
+}
+
+impl PatchEmbed {
+    fn new(
+        in_chans: usize,
+        embed_dim: usize,
+        k_size: usize,
+        stride: usize,
+        padding: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            stride,
+            padding,
+            ..Default::default()
+        };
+        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
+        Ok(Self { proj })
+    }
+}
+
+impl Module for PatchEmbed {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
+    }
+}
+
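+/// Multi-head self-attention over the (H, W) token grid, with optional decomposed
+/// relative positional embeddings (`rel_pos_h` / `rel_pos_w`).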
+#[derive(Debug)]
+struct Attention {
+    qkv: Linear,
+    proj: Linear,
+    num_heads: usize,
+    scale: f64,
+    use_rel_pos: bool,
+    rel_pos_hw: Option<(Tensor, Tensor)>,
+}
+
+impl Attention {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+        let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
+        let head_dim = dim / num_heads;
+        let scale = 1. / (head_dim as f64).sqrt();
+        let rel_pos_hw = if use_rel_pos {
+            let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
+            let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+            Some((h, w))
+        } else {
+            None
+        };
+        Ok(Self {
+            qkv,
+            proj,
+            num_heads,
+            scale,
+            use_rel_pos,
+            rel_pos_hw,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b, h, w, c) = xs.dims4()?;
+        let qkv = self
+            .qkv
+            .forward(xs)?
+            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
+            .permute((2, 0, 3, 1, 4))?
+            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
+        let q = qkv.i(0)?;
+        let k = qkv.i(1)?;
+        let v = qkv.i(2)?;
+        let attn = (q * self.scale)?.matmul(&k.t()?)?;
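+        // TODO: when `use_rel_pos` is set, the reference SAM implementation adds the
+        // decomposed relative positional embeddings (`rel_pos_h` / `rel_pos_w`) to these
+        // attention scores before the softmax; that step is not implemented here yet.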
+        if self.use_rel_pos {
+            todo!()
+        }
+        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = attn
+            .matmul(&v)?
+            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
+            .permute((0, 2, 3, 1, 4))?
+            // Merge the heads back: (B, H, W, num_heads, head_dim) -> (B, H, W, C).
+            .reshape((b, h, w, c))?;
+        self.proj.forward(&attn)
+    }
+}
+
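+/// Transformer block: pre-norm multi-head attention and an MLP, each wrapped in a
+/// residual connection. A non-zero `window_size` selects windowed attention.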
+#[derive(Debug)]
+struct Block {
+    norm1: LayerNorm,
+    attn: Attention,
+    norm2: LayerNorm,
+    mlp: crate::MlpBlock,
+    window_size: usize,
+}
+
+impl Block {
+    fn new(
+        dim: usize,
+        num_heads: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+        let attn = Attention::new(
+            dim,
+            num_heads,
+            qkv_bias,
+            use_rel_pos,
+            window_size,
+            vb.pp("attn"),
+        )?;
+        let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
+        Ok(Self {
+            norm1,
+            attn,
+            norm2,
+            mlp,
+            window_size,
+        })
+    }
+}
+
+impl Module for Block {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let shortcut = xs;
+        let xs = self.norm1.forward(xs)?;
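+        // TODO: for blocks with `window_size > 0`, the reference SAM implementation
+        // partitions the (H, W) token grid into windows here before running attention...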
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = self.attn.forward(&xs)?;
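+        // TODO: ...and merges the windows back into the full (H, W) grid here.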
+        if self.window_size > 0 {
+            todo!()
+        }
+        let xs = (xs + shortcut)?;
+        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
+    }
+}
+
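+/// ViT image encoder: patch embedding, optional absolute positional embedding, a stack
+/// of transformer blocks, and a convolutional neck producing the image embedding.
+///
+/// For reference, SAM's ViT-B variant uses `img_size` 1024, `patch_size` 16,
+/// `embed_dim` 768, `depth` 12, `num_heads` 12, `out_chans` 256 and `window_size` 14.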
+#[derive(Debug)]
+struct ImageEncoderViT {
+    img_size: usize,
+    patch_embed: PatchEmbed,
+    blocks: Vec<Block>,
+    neck_conv1: candle_nn::Conv2d,
+    neck_ln1: LayerNorm,
+    neck_conv2: candle_nn::Conv2d,
+    neck_ln2: LayerNorm,
+    pos_embed: Option<Tensor>,
+}
+
+impl ImageEncoderViT {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        img_size: usize,
+        patch_size: usize,
+        in_chans: usize,
+        embed_dim: usize,
+        depth: usize,
+        num_heads: usize,
+        out_chans: usize,
+        qkv_bias: bool,
+        use_rel_pos: bool,
+        use_abs_pos: bool,
+        window_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let patch_embed = PatchEmbed::new(
+            in_chans,
+            embed_dim,
+            patch_size,
+            patch_size,
+            0,
+            vb.pp("patch_embed"),
+        )?;
+        let mut blocks = Vec::with_capacity(depth);
+        let vb_b = vb.pp("blocks");
+        for i in 0..depth {
+            let block = Block::new(
+                embed_dim,
+                num_heads,
+                qkv_bias,
+                use_rel_pos,
+                window_size,
+                vb_b.pp(i),
+            )?;
+            blocks.push(block)
+        }
+        let neck_conv1 = candle_nn::conv2d_no_bias(
+            embed_dim,
+            out_chans,
+            1,
+            Default::default(),
+            vb.pp("neck.0"),
+        )?;
+        let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
+        let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?;
+        let pos_embed = if use_abs_pos {
+            let p = vb.get(
+                (1, img_size / patch_size, img_size / patch_size, embed_dim),
+                "pos_embed",
+            )?;
+            Some(p)
+        } else {
+            None
+        };
+        Ok(Self {
+            img_size,
+            patch_embed,
+            blocks,
+            neck_conv1,
+            neck_ln1,
+            neck_conv2,
+            neck_ln2,
+            pos_embed,
+        })
+    }
+}
+
+impl Module for ImageEncoderViT {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.patch_embed.forward(xs)?;
+        let mut xs = match &self.pos_embed {
+            Some(pos_embed) => (xs + pos_embed)?,
+            None => xs,
+        };
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        // candle_nn::LayerNorm normalizes over the last dimension, so keep the tensor in
+        // channels-last layout when applying the neck layer norms (matching SAM's
+        // channel-wise LayerNorm2d) and only switch to (B, C, H, W) around the convolutions.
+        xs.permute((0, 3, 1, 2))?
+            .apply(&self.neck_conv1)?
+            .permute((0, 2, 3, 1))?
+            .apply(&self.neck_ln1)?
+            .permute((0, 3, 1, 2))?
+            .apply(&self.neck_conv2)?
+            .permute((0, 2, 3, 1))?
+            .apply(&self.neck_ln2)?
+            // Return the image embedding in (B, C, H, W) layout.
+            .permute((0, 3, 1, 2))
+    }
+}