diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_image_encoder.rs | 153 |
1 files changed, 132 insertions, 21 deletions
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs index cfcdbb38..f5db2830 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-examples/examples/segment-anything/model_image_encoder.rs @@ -70,6 +70,60 @@ impl Attention { rel_pos_hw, }) } + + fn add_decomposed_rel_pos( + &self, + attn: Tensor, + q: &Tensor, + (q_h, q_w): (usize, usize), + (k_h, k_w): (usize, usize), + ) -> Result<Tensor> { + match &self.rel_pos_hw { + Some((rel_pos_h, rel_pos_w)) => { + let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; + let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; + let (b, _, dim) = q.dims3()?; + let r_q = q.reshape((b, q_h, q_w, dim))?; + // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; + // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + let rel_w = r_q + .transpose(1, 2)? // -> bwhc + .contiguous()? + .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk + .transpose(1, 2)?; + (attn.reshape((b, q_h, q_w, k_h, k_w))? + + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? + .reshape((b, q_h * q_w, k_h * k_w)) + } + None => Ok(attn), + } + } +} + +fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> { + let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; + let dev = rel_pos.device(); + let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { + todo!("interpolation") + } else { + rel_pos + }; + let q_coords = Tensor::arange(0u32, q_size as u32, dev)? + .reshape((q_size, 1))? + .to_dtype(DType::F32)?; + let k_coords = Tensor::arange(0u32, k_size as u32, dev)? + .reshape((1, k_size))? + .to_dtype(DType::F32)?; + let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; + let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; + let relative_coords = (q_coords.broadcast_sub(&k_coords)? + + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; + let (d1, d2) = relative_coords.dims2()?; + let relative_coords = relative_coords.to_dtype(DType::U32)?; + rel_pos_resized + .index_select(&relative_coords.reshape(d1 * d2)?, 0)? + .reshape((d1, d2, rel_pos_resized.dim(1)?)) } impl Module for Attention { @@ -77,24 +131,22 @@ impl Module for Attention { let (b, h, w, c) = xs.dims4()?; let qkv = self .qkv - .forward(xs)? + .forward(&xs.flatten_to(1)?)? .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? .permute((2, 0, 3, 1, 4))? .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; let q = qkv.i(0)?; let k = qkv.i(1)?; let v = qkv.i(2)?; - let attn = (q * self.scale)?.matmul(&k.t()?)?; - if self.use_rel_pos { - todo!() - } + let attn = (&q * self.scale)?.matmul(&k.t()?)?; + let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?; let attn = candle_nn::ops::softmax_last_dim(&attn)?; + let attn = attn.matmul(&v)?; let attn = attn - .matmul(&v)? .reshape((b, self.num_heads, h, w, c / self.num_heads))? .permute((0, 2, 3, 1, 4))? - .reshape((b, h, w, c / self.num_heads))?; - self.proj.forward(&attn) + .reshape((b, h * w, c))?; + self.proj.forward(&attn)?.reshape((b, h, w, c)) } } @@ -117,8 +169,8 @@ impl Block { input_size: (usize, usize), vb: VarBuilder, ) -> Result<Self> { - let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; let input_size_attn = if window_size == 0 { input_size } else { @@ -132,7 +184,7 @@ impl Block { input_size_attn, vb.pp("attn"), )?; - let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?; + let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; Ok(Self { norm1, attn, @@ -143,17 +195,76 @@ impl Block { } } +fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { + let (b, h, w, c) = xs.dims4()?; + let pad_h = (window_size - h % window_size) % window_size; + let pad_w = (window_size - w % window_size) % window_size; + let xs = if pad_h > 0 { + xs.pad_with_zeros(1, 0, pad_h)? + } else { + xs + }; + let xs = if pad_w > 0 { + xs.pad_with_zeros(2, 0, pad_w)? + } else { + xs + }; + let (h_p, w_p) = (h + pad_h, w + pad_w); + let windows = xs + .reshape(( + b, + h_p / window_size, + window_size, + w_p / window_size, + window_size, + c, + ))? + .transpose(2, 3)? + .contiguous()? + .flatten_to(2)?; + Ok((windows, (h_p, w_p))) +} + +fn window_unpartition( + windows: Tensor, + window_size: usize, + (h_p, w_p): (usize, usize), + (h, w): (usize, usize), +) -> Result<Tensor> { + let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); + let xs = windows + .reshape(( + b, + h_p / window_size, + w_p / window_size, + window_size, + window_size, + windows.elem_count() / b / h_p / w_p, + ))? + .transpose(2, 3)? + .contiguous()? + .reshape((b, h_p, w_p, windows.elem_count() / b / h_p / w_p))?; + let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; + let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; + Ok(xs) +} + impl Module for Block { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let shortcut = xs; let xs = self.norm1.forward(xs)?; - if self.window_size > 0 { - todo!() - } + let hw = (xs.dim(1)?, xs.dim(2)?); + let (xs, pad_hw) = if self.window_size > 0 { + window_partition(xs, self.window_size)? + } else { + (xs, (0, 0)) + }; let xs = self.attn.forward(&xs)?; - if self.window_size > 0 { - todo!() - } + let xs = if self.window_size > 0 { + window_unpartition(xs, self.window_size, pad_hw, hw)? + } else { + xs + }; let xs = (xs + shortcut)?; &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? } @@ -165,9 +276,9 @@ pub struct ImageEncoderViT { patch_embed: PatchEmbed, blocks: Vec<Block>, neck_conv1: candle_nn::Conv2d, - neck_ln1: LayerNorm, + neck_ln1: crate::LayerNorm2d, neck_conv2: candle_nn::Conv2d, - neck_ln2: LayerNorm, + neck_ln2: crate::LayerNorm2d, pos_embed: Option<Tensor>, } @@ -222,13 +333,13 @@ impl ImageEncoderViT { Default::default(), vb.pp("neck.0"), )?; - let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?; + let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; let cfg = candle_nn::Conv2dConfig { padding: 1, ..Default::default() }; let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?; + let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; let pos_embed = if use_abs_pos { let p = vb.get( (1, img_size / patch_size, img_size / patch_size, embed_dim), |