Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_image_encoder.rs | 483 |
1 file changed, 0 insertions(+), 483 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
deleted file mode 100644
index 76cd15d0..00000000
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ /dev/null
@@ -1,483 +0,0 @@
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
-
-#[derive(Debug)]
-struct PatchEmbed {
-    proj: candle_nn::Conv2d,
-    span: tracing::Span,
-}
-
-impl PatchEmbed {
-    fn new(
-        in_chans: usize,
-        embed_dim: usize,
-        k_size: usize,
-        stride: usize,
-        padding: usize,
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let cfg = candle_nn::Conv2dConfig {
-            stride,
-            padding,
-            ..Default::default()
-        };
-        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
-        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
-        Ok(Self { proj, span })
-    }
-}
-
-impl Module for PatchEmbed {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
-    }
-}
-
-// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
-// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
-// (attn.reshape((b, q_h, q_w, k_h, k_w))?
-//     + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
-// .reshape((b, q_h * q_w, k_h * k_w))
-// Ideally we would perform this operation in place but this is not supported in candle at the
-// moment. We should also investigate using f16 rather than f32.
-struct Add3(usize, usize, usize, usize, usize);
-impl candle::CustomOp3 for Add3 {
-    fn name(&self) -> &'static str {
-        "add3"
-    }
-
-    fn cpu_fwd(
-        &self,
-        s1: &candle::CpuStorage,
-        l1: &candle::Layout,
-        s2: &candle::CpuStorage,
-        l2: &candle::Layout,
-        s3: &candle::CpuStorage,
-        l3: &candle::Layout,
-    ) -> Result<(candle::CpuStorage, candle::Shape)> {
-        use rayon::prelude::*;
-
-        let Add3(b, q_h, q_w, k_h, k_w) = *self;
-        let s1 = s1.as_slice::<f32>()?;
-        let s1 = match l1.contiguous_offsets() {
-            None => candle::bail!("input1 has to be contiguous"),
-            Some((o1, o2)) => &s1[o1..o2],
-        };
-        let s2 = s2.as_slice::<f32>()?;
-        let s2 = match l2.contiguous_offsets() {
-            None => candle::bail!("input2 has to be contiguous"),
-            Some((o1, o2)) => &s2[o1..o2],
-        };
-        let s3 = s3.as_slice::<f32>()?;
-        let s3 = match l3.contiguous_offsets() {
-            None => candle::bail!("input3 has to be contiguous"),
-            Some((o1, o2)) => &s3[o1..o2],
-        };
-        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
-        dst.par_chunks_exact_mut(k_h * k_w)
-            .enumerate()
-            .for_each(|(b_idx, dst)| {
-                let s1_idx = b_idx * k_h * k_w;
-                let s2_idx = b_idx * k_h;
-                let s3_idx = b_idx * k_w;
-                for h_idx in 0..k_h {
-                    let s1_idx = s1_idx + h_idx * k_w;
-                    let s2_idx = s2_idx + h_idx;
-                    let dst_idx = h_idx * k_w;
-                    for w_idx in 0..k_w {
-                        let s1_idx = s1_idx + w_idx;
-                        let s3_idx = s3_idx + w_idx;
-                        let dst_idx = dst_idx + w_idx;
-                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
-                    }
-                }
-            });
-        let dst = candle::WithDType::to_cpu_storage_owned(dst);
-        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
-    }
-}
-
-#[derive(Debug)]
-struct Attention {
-    qkv: crate::Linear,
-    proj: crate::Linear,
-    num_heads: usize,
-    scale: f64,
-    rel_pos_hw: Option<(Tensor, Tensor)>,
-    span: tracing::Span,
-    span_matmul: tracing::Span,
-    span_rel_pos: tracing::Span,
-    span_softmax: tracing::Span,
-}
-
-impl Attention {
-    fn new(
-        dim: usize,
-        num_heads: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        input_size: (usize, usize),
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let span = tracing::span!(tracing::Level::TRACE, "attention");
-        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
-        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
-        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
-        let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
-        let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
-        let head_dim = dim / num_heads;
-        let scale = 1. / (head_dim as f64).sqrt();
-        let rel_pos_hw = if use_rel_pos {
-            let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
-            let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
-            Some((h, w))
-        } else {
-            None
-        };
-        Ok(Self {
-            qkv,
-            proj,
-            num_heads,
-            scale,
-            rel_pos_hw,
-            span,
-            span_matmul,
-            span_rel_pos,
-            span_softmax,
-        })
-    }
-
-    fn add_decomposed_rel_pos(
-        &self,
-        attn: Tensor,
-        q: &Tensor,
-        (q_h, q_w): (usize, usize),
-        (k_h, k_w): (usize, usize),
-    ) -> Result<Tensor> {
-        match &self.rel_pos_hw {
-            Some((rel_pos_h, rel_pos_w)) => {
-                let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
-                let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
-                let (b, _, dim) = q.dims3()?;
-                let r_q = q.reshape((b, q_h, q_w, dim))?;
-                // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
-                // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
-                let rel_w = r_q
-                    .transpose(1, 2)? // -> bwhc
-                    .contiguous()?
-                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                    .transpose(1, 2)?
-                    .contiguous()?;
-                if attn.device().is_cpu() {
-                    let op = Add3(b, q_h, q_w, k_h, k_w);
-                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
-                } else {
-                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
-                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
-                        .reshape((b, q_h * q_w, k_h * k_w))
-                }
-            }
-            None => Ok(attn),
-        }
-    }
-}
-
-fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
-    let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
-    let dev = rel_pos.device();
-    let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
-        todo!("interpolation")
-    } else {
-        rel_pos
-    };
-    let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
-        .reshape((q_size, 1))?
-        .to_dtype(DType::F32)?;
-    let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
-        .reshape((1, k_size))?
-        .to_dtype(DType::F32)?;
-    let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
-    let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
-    let relative_coords = (q_coords.broadcast_sub(&k_coords)?
-        + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
-    let (d1, d2) = relative_coords.dims2()?;
-    let relative_coords = relative_coords.to_dtype(DType::U32)?;
-    rel_pos_resized
-        .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
-        .reshape((d1, d2, ()))
-}
-
-impl Module for Attention {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let (b, h, w, c) = xs.dims4()?;
-        let qkv = self
-            .qkv
-            .forward(&xs.flatten_to(1)?)?
-            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
-            .permute((2, 0, 3, 1, 4))?
-            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
-        let q = qkv.i(0)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
-        let attn = {
-            let _enter = self.span_matmul.enter();
-            (&q * self.scale)?.matmul(&k.t()?)?
-        };
-        let attn = {
-            let _enter = self.span_rel_pos.enter();
-            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
-        };
-        let attn = {
-            let _enter = self.span_softmax.enter();
-            candle_nn::ops::softmax_last_dim(&attn)?
-        };
-        let attn = {
-            let _enter = self.span_matmul.enter();
-            attn.matmul(&v)?
-        };
-        let attn = attn
-            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
-            .permute((0, 2, 3, 1, 4))?
-            .reshape((b, h * w, c))?;
-        self.proj.forward(&attn)?.reshape((b, h, w, c))
-    }
-}
-
-#[derive(Debug)]
-struct Block {
-    norm1: LayerNorm,
-    attn: Attention,
-    norm2: LayerNorm,
-    mlp: crate::MlpBlock,
-    window_size: usize,
-    span: tracing::Span,
-}
-
-impl Block {
-    fn new(
-        dim: usize,
-        num_heads: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        window_size: usize,
-        input_size: (usize, usize),
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
-        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
-        let input_size_attn = if window_size == 0 {
-            input_size
-        } else {
-            (window_size, window_size)
-        };
-        let attn = Attention::new(
-            dim,
-            num_heads,
-            qkv_bias,
-            use_rel_pos,
-            input_size_attn,
-            vb.pp("attn"),
-        )?;
-        let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
-        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
-        Ok(Self {
-            norm1,
-            attn,
-            norm2,
-            mlp,
-            window_size,
-            span,
-        })
-    }
-}
-
-fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
-    let (b, h, w, c) = xs.dims4()?;
-    let pad_h = (window_size - h % window_size) % window_size;
-    let pad_w = (window_size - w % window_size) % window_size;
-    let xs = if pad_h > 0 {
-        xs.pad_with_zeros(1, 0, pad_h)?
-    } else {
-        xs
-    };
-    let xs = if pad_w > 0 {
-        xs.pad_with_zeros(2, 0, pad_w)?
-    } else {
-        xs
-    };
-    let (h_p, w_p) = (h + pad_h, w + pad_w);
-    let windows = xs
-        .reshape((
-            b,
-            h_p / window_size,
-            window_size,
-            w_p / window_size,
-            window_size,
-            c,
-        ))?
-        .transpose(2, 3)?
-        .contiguous()?
-        .flatten_to(2)?;
-    Ok((windows, (h_p, w_p)))
-}
-
-fn window_unpartition(
-    windows: Tensor,
-    window_size: usize,
-    (h_p, w_p): (usize, usize),
-    (h, w): (usize, usize),
-) -> Result<Tensor> {
-    let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
-    let xs = windows
-        .reshape((
-            b,
-            h_p / window_size,
-            w_p / window_size,
-            window_size,
-            window_size,
-            windows.elem_count() / b / h_p / w_p,
-        ))?
-        .transpose(2, 3)?
-        .contiguous()?
-        .reshape((b, h_p, w_p, ()))?;
-    let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
-    let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
-    Ok(xs)
-}
-
-impl Module for Block {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let shortcut = xs;
-        let xs = self.norm1.forward(xs)?;
-        let hw = (xs.dim(1)?, xs.dim(2)?);
-        let (xs, pad_hw) = if self.window_size > 0 {
-            window_partition(xs, self.window_size)?
-        } else {
-            (xs, (0, 0))
-        };
-        let xs = self.attn.forward(&xs)?;
-        let xs = if self.window_size > 0 {
-            window_unpartition(xs, self.window_size, pad_hw, hw)?
-        } else {
-            xs
-        };
-        let xs = (xs + shortcut)?;
-        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
-    }
-}
-
-#[derive(Debug)]
-pub struct ImageEncoderViT {
-    patch_embed: PatchEmbed,
-    blocks: Vec<Block>,
-    neck_conv1: candle_nn::Conv2d,
-    neck_ln1: crate::LayerNorm2d,
-    neck_conv2: candle_nn::Conv2d,
-    neck_ln2: crate::LayerNorm2d,
-    pos_embed: Option<Tensor>,
-    span: tracing::Span,
-}
-
-impl ImageEncoderViT {
-    #[allow(clippy::too_many_arguments)]
-    pub fn new(
-        img_size: usize,
-        patch_size: usize,
-        in_chans: usize,
-        embed_dim: usize,
-        depth: usize,
-        num_heads: usize,
-        out_chans: usize,
-        qkv_bias: bool,
-        use_rel_pos: bool,
-        use_abs_pos: bool,
-        window_size: usize,
-        global_attn_indexes: &[usize],
-        vb: VarBuilder,
-    ) -> Result<Self> {
-        let patch_embed = PatchEmbed::new(
-            in_chans,
-            embed_dim,
-            patch_size,
-            patch_size,
-            0,
-            vb.pp("patch_embed"),
-        )?;
-        let mut blocks = Vec::with_capacity(depth);
-        let vb_b = vb.pp("blocks");
-        for i in 0..depth {
-            let window_size = if global_attn_indexes.contains(&i) {
-                0
-            } else {
-                window_size
-            };
-            let block = Block::new(
-                embed_dim,
-                num_heads,
-                qkv_bias,
-                use_rel_pos,
-                window_size,
-                (img_size / patch_size, img_size / patch_size),
-                vb_b.pp(i),
-            )?;
-            blocks.push(block)
-        }
-        let neck_conv1 = candle_nn::conv2d_no_bias(
-            embed_dim,
-            out_chans,
-            1,
-            Default::default(),
-            vb.pp("neck.0"),
-        )?;
-        let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
-        let cfg = candle_nn::Conv2dConfig {
-            padding: 1,
-            ..Default::default()
-        };
-        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
-        let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
-        let pos_embed = if use_abs_pos {
-            let p = vb.get(
-                (1, img_size / patch_size, img_size / patch_size, embed_dim),
-                "pos_embed",
-            )?;
-            Some(p)
-        } else {
-            None
-        };
-        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
-        Ok(Self {
-            patch_embed,
-            blocks,
-            neck_conv1,
-            neck_ln1,
-            neck_conv2,
-            neck_ln2,
-            pos_embed,
-            span,
-        })
-    }
-}
-
-impl Module for ImageEncoderViT {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        let xs = self.patch_embed.forward(xs)?;
-        let mut xs = match &self.pos_embed {
-            Some(pos_embed) => (xs + pos_embed)?,
-            None => xs,
-        };
-        for block in self.blocks.iter() {
-            xs = block.forward(&xs)?
-        }
-        xs.permute((0, 3, 1, 2))?
-            .apply(&self.neck_conv1)?
-            .apply(&self.neck_ln1)?
-            .apply(&self.neck_conv2)?
-            .apply(&self.neck_ln2)
-    }
-}
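A note on the Add3 custom op in the deleted file: it fuses the broadcast sum spelled out in its leading comment. The sketch below restates that reference path as a standalone function with toy sizes; it only uses tensor ops that already appear in the file, and the helper name naive_add3 is mine, not part of the crate.

    use candle::{DType, Device, Result, Tensor};

    // Reference path for what Add3 computes, taken from the comment in the file:
    // attn is (b, q_h*q_w, k_h*k_w), rel_h is (b, q_h, q_w, k_h), rel_w is (b, q_h, q_w, k_w).
    fn naive_add3(
        attn: &Tensor,
        rel_h: &Tensor,
        rel_w: &Tensor,
        (b, q_h, q_w, k_h, k_w): (usize, usize, usize, usize, usize),
    ) -> Result<Tensor> {
        (attn.reshape((b, q_h, q_w, k_h, k_w))?
            + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
            .reshape((b, q_h * q_w, k_h * k_w))
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, q_h, q_w, k_h, k_w) = (1, 2, 2, 2, 2);
        let attn = Tensor::zeros((b, q_h * q_w, k_h * k_w), DType::F32, &dev)?;
        let rel_h = Tensor::ones((b, q_h, q_w, k_h), DType::F32, &dev)?;
        let rel_w = Tensor::ones((b, q_h, q_w, k_w), DType::F32, &dev)?;
        // Every output element is 0 + 1 + 1 = 2, matching what the fused CPU path writes.
        let out = naive_add3(&attn, &rel_h, &rel_w, (b, q_h, q_w, k_h, k_w))?;
        assert_eq!(out.dims(), &[b, q_h * q_w, k_h * k_w]);
        Ok(())
    }

The custom op exists because this broadcast materializes a (b, q_h, q_w, k_h, k_w) intermediate; the fused loop writes each output element once instead.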
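window_partition pads the feature map up to a multiple of window_size and splits it into non-overlapping windows; window_unpartition is its inverse, cropping the padding away again. A minimal round-trip check, assuming it runs inside the same module so the private helpers are visible:

    use candle::{Device, Result, Tensor};

    fn windows_roundtrip() -> Result<()> {
        let dev = Device::Cpu;
        // A 5x5 map with window size 3 forces one row and one column of zero padding.
        let xs = Tensor::rand(0f32, 1f32, (1, 5, 5, 8), &dev)?;
        let (windows, (h_p, w_p)) = window_partition(xs.clone(), 3)?;
        assert_eq!((h_p, w_p), (6, 6)); // padded up to multiples of 3
        assert_eq!(windows.dims(), &[4, 3, 3, 8]); // a 2x2 grid of 3x3 windows
        let restored = window_unpartition(windows, 3, (h_p, w_p), (5, 5))?;
        assert_eq!(restored.dims(), xs.dims()); // padding cropped away again
        Ok(())
    }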
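Finally, for orientation, a construction sketch for the encoder. The hyper-parameters below are the ones commonly quoted for the ViT-B SAM checkpoint (embed_dim 768, depth 12, 12 heads, window size 14, global attention at blocks 2/5/8/11); they come from the upstream segment-anything configuration, not from this diff, so treat them as an assumption. The sketch also presumes it runs inside the example crate where ImageEncoderViT and its crate:: helpers are defined.

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn build_vit_b_encoder() -> Result<ImageEncoderViT> {
        let dev = Device::Cpu;
        // Zero-initialised weights for illustration; a real run would load the SAM
        // checkpoint into the VarBuilder instead.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        ImageEncoderViT::new(
            1024,           // img_size: SAM works on 1024x1024 inputs
            16,             // patch_size -> a 64x64 token grid
            3,              // in_chans (RGB)
            768,            // embed_dim (ViT-B)
            12,             // depth
            12,             // num_heads
            256,            // out_chans produced by the neck
            true,           // qkv_bias
            true,           // use_rel_pos
            true,           // use_abs_pos
            14,             // window_size for the local-attention blocks
            &[2, 5, 8, 11], // global_attn_indexes: blocks that attend globally
            vb,
        )
    }

Blocks listed in global_attn_indexes get window_size 0 and therefore attend over the full 64x64 grid; all other blocks attend within 14x14 windows.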