Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/depth_anything_v2.rs | 553
-rw-r--r--  candle-transformers/src/models/dinov2.rs             |  78
-rw-r--r--  candle-transformers/src/models/mod.rs                |   1
3 files changed, 632 insertions, 0 deletions
diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs
new file mode 100644
index 00000000..9eee6d11
--- /dev/null
+++ b/candle-transformers/src/models/depth_anything_v2.rs
@@ -0,0 +1,553 @@
+use candle::D::Minus1;
+use candle::{Module, Result, Tensor};
+use candle_nn::ops::Identity;
+use candle_nn::{
+    batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
+    BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
+};
+
+use crate::models::dinov2::DinoVisionTransformer;
+
+pub struct DepthAnythingV2Config {
+    out_channel_sizes: [usize; 4],
+    in_channel_size: usize, // embed_dim in the Dino model
+    num_features: usize,
+    use_batch_norm: bool,
+    use_class_token: bool,
+    layer_ids_vits: Vec<usize>,
+    input_image_size: usize,
+    target_patch_size: usize,
+}
+
+impl DepthAnythingV2Config {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        out_channel_sizes: [usize; 4],
+        in_channel_size: usize,
+        num_features: usize,
+        use_batch_norm: bool,
+        use_class_token: bool,
+        layer_ids_vits: Vec<usize>,
+        input_image_size: usize,
+        target_patch_size: usize,
+    ) -> Self {
+        Self {
+            out_channel_sizes,
+            in_channel_size,
+            num_features,
+            use_batch_norm,
+            use_class_token,
+            layer_ids_vits,
+            input_image_size,
+            target_patch_size,
+        }
+    }
+
+    pub fn vit_small() -> Self {
+        Self {
+            out_channel_sizes: [48, 96, 192, 384],
+            in_channel_size: 384,
+            num_features: 64,
+            use_batch_norm: false,
+            use_class_token: false,
+            layer_ids_vits: vec![2, 5, 8, 11],
+            input_image_size: 518,
+            target_patch_size: 518 / 14,
+        }
+    }
+
+    pub fn vit_base() -> Self {
+        Self {
+            out_channel_sizes: [96, 192, 384, 768],
+            in_channel_size: 768,
+            num_features: 128,
+            use_batch_norm: false,
+            use_class_token: false,
+            layer_ids_vits: vec![2, 5, 8, 11],
+            input_image_size: 518,
+            target_patch_size: 518 / 14,
+        }
+    }
+
+    pub fn vit_large() -> Self {
+        Self {
+            out_channel_sizes: [256, 512, 1024, 1024],
+            in_channel_size: 1024,
+            num_features: 256,
+            use_batch_norm: false,
+            use_class_token: false,
+            layer_ids_vits: vec![4, 11, 17, 23],
+            input_image_size: 518,
+            target_patch_size: 518 / 14,
+        }
+    }
+
+    pub fn vit_giant() -> Self {
+        Self {
+            out_channel_sizes: [1536, 1536, 1536, 1536],
+            in_channel_size: 1536,
+            num_features: 384,
+            use_batch_norm: false,
+            use_class_token: false,
+            layer_ids_vits: vec![9, 19, 29, 39],
+            input_image_size: 518,
+            target_patch_size: 518 / 14,
+        }
+    }
+}
+
+pub struct ResidualConvUnit {
+    activation: Activation,
+    conv1: Conv2d,
+    conv2: Conv2d,
+    batch_norm1: Option<BatchNorm>,
+    batch_norm2: Option<BatchNorm>,
+}
+
+impl ResidualConvUnit {
+    pub fn new(
+        conf: &DepthAnythingV2Config,
+        activation: Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const KERNEL_SIZE: usize = 3;
+        let conv_cfg = Conv2dConfig {
+            padding: 1,
+            stride: 1,
+            dilation: 1,
+            groups: 1,
+        };
+        let conv1 = conv2d(
+            conf.num_features,
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("conv1"),
+        )?;
+        let conv2 = conv2d(
+            conf.num_features,
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("conv2"),
+        )?;
+
+        let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
+            true => {
+                let batch_norm_cfg = BatchNormConfig {
+                    eps: 1e-05,
+                    remove_mean: false,
+                    affine: true,
+                    momentum: 0.1,
+                };
+                (
+                    Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
+                    Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
+                )
+            }
+            false => (None, None),
+        };
+
+        Ok(Self {
+            activation,
+            conv1,
+            conv2,
+            batch_norm1,
+            batch_norm2,
+        })
+    }
+}
+
+impl Module for ResidualConvUnit {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let out = self.activation.forward(xs)?;
+        let out = self.conv1.forward(&out)?;
+        let out = if let Some(batch_norm1) = &self.batch_norm1 {
+            batch_norm1.forward_train(&out)?
+        } else {
+            out
+        };
+
+        let out = self.activation.forward(&out)?;
+        let out = self.conv2.forward(&out)?;
+        let out = if let Some(batch_norm2) = &self.batch_norm2 {
+            batch_norm2.forward_train(&out)?
+        } else {
+            out
+        };
+
+        out + xs
+    }
+}
+
+pub struct FeatureFusionBlock {
+    res_conv_unit1: ResidualConvUnit,
+    res_conv_unit2: ResidualConvUnit,
+    output_conv: Conv2d,
+    target_patch_size: usize,
+}
+
+impl FeatureFusionBlock {
+    pub fn new(
+        conf: &DepthAnythingV2Config,
+        target_patch_size: usize,
+        activation: Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const KERNEL_SIZE: usize = 1;
+        let conv_cfg = Conv2dConfig {
+            padding: 0,
+            stride: 1,
+            dilation: 1,
+            groups: 1,
+        };
+        let output_conv = conv2d(
+            conf.num_features,
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("out_conv"),
+        )?;
+        let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
+        let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
+
+        Ok(Self {
+            res_conv_unit1,
+            res_conv_unit2,
+            output_conv,
+            target_patch_size,
+        })
+    }
+}
+
+impl Module for FeatureFusionBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let out = self.res_conv_unit2.forward(xs)?;
+        let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
+
+        self.output_conv.forward(&out)
+    }
+}
+
+pub struct Scratch {
+    layer1_rn: Conv2d,
+    layer2_rn: Conv2d,
+    layer3_rn: Conv2d,
+    layer4_rn: Conv2d,
+    refine_net1: FeatureFusionBlock,
+    refine_net2: FeatureFusionBlock,
+    refine_net3: FeatureFusionBlock,
+    refine_net4: FeatureFusionBlock,
+    output_conv1: Conv2d,
+    output_conv2: Sequential,
+}
+
+impl Scratch {
+    pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
+        const KERNEL_SIZE: usize = 3;
+        let conv_cfg = Conv2dConfig {
+            padding: 1,
+            stride: 1,
+            dilation: 1,
+            groups: 1,
+        };
+
+        let layer1_rn = conv2d_no_bias(
+            conf.out_channel_sizes[0],
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("layer1_rn"),
+        )?;
+        let layer2_rn = conv2d_no_bias(
+            conf.out_channel_sizes[1],
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("layer2_rn"),
+        )?;
+        let layer3_rn = conv2d_no_bias(
+            conf.out_channel_sizes[2],
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("layer3_rn"),
+        )?;
+        let layer4_rn = conv2d_no_bias(
+            conf.out_channel_sizes[3],
+            conf.num_features,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("layer4_rn"),
+        )?;
+
+        let refine_net1 = FeatureFusionBlock::new(
+            conf,
+            conf.target_patch_size * 8,
+            Activation::Relu,
+            vb.pp("refinenet1"),
+        )?;
+        let refine_net2 = FeatureFusionBlock::new(
+            conf,
+            conf.target_patch_size * 4,
+            Activation::Relu,
+            vb.pp("refinenet2"),
+        )?;
+        let refine_net3 = FeatureFusionBlock::new(
+            conf,
+            conf.target_patch_size * 2,
+            Activation::Relu,
+            vb.pp("refinenet3"),
+        )?;
+        let refine_net4 = FeatureFusionBlock::new(
+            conf,
+            conf.target_patch_size,
+            Activation::Relu,
+            vb.pp("refinenet4"),
+        )?;
+
+        let conv_cfg = Conv2dConfig {
+            padding: 1,
+            stride: 1,
+            dilation: 1,
+            groups: 1,
+        };
+        let output_conv1 = conv2d(
+            conf.num_features,
+            conf.num_features / 2,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("output_conv1"),
+        )?;
+
+        let output_conv2 = seq();
+        const HEAD_FEATURES_2: usize = 32;
+        const OUT_CHANNELS_2: usize = 1;
+        const KERNEL_SIZE_2: usize = 1;
+        let output_conv2 = output_conv2.add(conv2d(
+            conf.num_features / 2,
+            HEAD_FEATURES_2,
+            KERNEL_SIZE,
+            conv_cfg,
+            vb.pp("output_conv2").pp("0"),
+        )?);
+        let output_conv2 = output_conv2
+            .add(Activation::Relu)
+            .add(conv2d(
+                HEAD_FEATURES_2,
+                OUT_CHANNELS_2,
+                KERNEL_SIZE_2,
+                conv_cfg,
+                vb.pp("output_conv2").pp("2"),
+            )?)
+            .add(Activation::Relu);
+
+        Ok(Self {
+            layer1_rn,
+            layer2_rn,
+            layer3_rn,
+            layer4_rn,
+            refine_net1,
+            refine_net2,
+            refine_net3,
+            refine_net4,
+            output_conv1,
+            output_conv2,
+        })
+    }
+}
+
+const NUM_CHANNELS: usize = 4;
+
+pub struct DPTHead<'a> {
+    conf: &'a DepthAnythingV2Config,
+    projections: Vec<Conv2d>,
+    resize_layers: Vec<Box<dyn Module>>,
+    readout_projections: Vec<Sequential>,
+    scratch: Scratch,
+}
+
+impl<'a> DPTHead<'a> {
+    pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
+        let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
+        for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
+            projections.push(conv2d(
+                conf.in_channel_size,
+                *out_channel_size,
+                1,
+                Default::default(),
+                vb.pp("projects").pp(conv_index.to_string()),
+            )?);
+        }
+
+        let resize_layers: Vec<Box<dyn Module>> = vec![
+            Box::new(conv_transpose2d(
+                conf.out_channel_sizes[0],
+                conf.out_channel_sizes[0],
+                4,
+                ConvTranspose2dConfig {
+                    padding: 0,
+                    stride: 4,
+                    dilation: 1,
+                    output_padding: 0,
+                },
+                vb.pp("resize_layers").pp("0"),
+            )?),
+            Box::new(conv_transpose2d(
+                conf.out_channel_sizes[1],
+                conf.out_channel_sizes[1],
+                2,
+                ConvTranspose2dConfig {
+                    padding: 0,
+                    stride: 2,
+                    dilation: 1,
+                    output_padding: 0,
+                },
+                vb.pp("resize_layers").pp("1"),
+            )?),
+            Box::new(Identity::new()),
+            Box::new(conv2d(
+                conf.out_channel_sizes[3],
+                conf.out_channel_sizes[3],
+                3,
+                Conv2dConfig {
+                    padding: 1,
+                    stride: 2,
+                    dilation: 1,
+                    groups: 1,
+                },
+                vb.pp("resize_layers").pp("3"),
+            )?),
+        ];
+
+        let readout_projections = if conf.use_class_token {
+            let mut rop = Vec::with_capacity(NUM_CHANNELS);
+            for rop_index in 0..NUM_CHANNELS {
+                rop.push(seq()
+                    .add(linear(
+                        2 * conf.in_channel_size,
+                        conf.in_channel_size,
+                        vb.pp("readout_projects").pp(rop_index.to_string()),
+                    )?)
+                    .add(Activation::Gelu));
+            }
+            rop
+        } else {
+            vec![]
+        };
+
+        let scratch = Scratch::new(conf, vb.pp("scratch"))?;
+
+        Ok(Self {
+            conf,
+            projections,
+            resize_layers,
+            readout_projections,
+            scratch,
+        })
+    }
+}
+
+impl Module for DPTHead<'_> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
+        for i in 0..NUM_CHANNELS {
+            let x = if self.conf.use_class_token {
+                let x = xs.get(i)?.get(0)?;
+                let class_token = xs.get(i)?.get(1)?;
+                let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
+                let to_cat = [x, readout];
+                let cat = Tensor::cat(&to_cat, Minus1)?;
+                self.readout_projections[i].forward(&cat)?
+            } else {
+                xs.get(i)?
+            };
+            let x_dims = x.dims();
+
+            let x = x.permute((0, 2, 1))?.reshape((
+                x_dims[0],
+                x_dims[x_dims.len() - 1],
+                self.conf.target_patch_size,
+                self.conf.target_patch_size,
+            ))?;
+            let x = self.projections[i].forward(&x)?;
+
+            let x = self.resize_layers[i].forward(&x)?;
+            out.push(x);
+        }
+
+        let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
+        let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
+        let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
+        let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
+
+        let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
+
+        let res3_out = self
+            .scratch
+            .refine_net3
+            .res_conv_unit1
+            .forward(&layer_3_rn)?;
+        let res3_out = path4.add(&res3_out)?;
+        let path3 = self.scratch.refine_net3.forward(&res3_out)?;
+
+        let res2_out = self
+            .scratch
+            .refine_net2
+            .res_conv_unit1
+            .forward(&layer_2_rn)?;
+        let res2_out = path3.add(&res2_out)?;
+        let path2 = self.scratch.refine_net2.forward(&res2_out)?;
+
+        let res1_out = self
+            .scratch
+            .refine_net1
+            .res_conv_unit1
+            .forward(&layer_1_rn)?;
+        let res1_out = path2.add(&res1_out)?;
+        let path1 = self.scratch.refine_net1.forward(&res1_out)?;
+
+        let out = self.scratch.output_conv1.forward(&path1)?;
+
+        let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
+
+        self.scratch.output_conv2.forward(&out)
+    }
+}
+
+pub struct DepthAnythingV2<'a> {
+    pretrained: &'a DinoVisionTransformer,
+    depth_head: DPTHead<'a>,
+    conf: &'a DepthAnythingV2Config,
+}
+
+impl<'a> DepthAnythingV2<'a> {
+    pub fn new(
+        pretrained: &'a DinoVisionTransformer,
+        conf: &'a DepthAnythingV2Config,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
+
+        Ok(Self {
+            pretrained,
+            depth_head,
+            conf,
+        })
+    }
+}
+
+impl<'a> Module for DepthAnythingV2<'a> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let features = self.pretrained.get_intermediate_layers(
+            xs,
+            &self.conf.layer_ids_vits,
+            false,
+            false,
+            true,
+        )?;
+        let depth = self.depth_head.forward(&features)?;
+
+        depth.relu()
+    }
+}
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 757aa88a..00e501ce 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -258,6 +258,84 @@ impl DinoVisionTransformer {
         let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
         &xs + &self.interpolate_pos_encoding(&xs, w, h)?
     }
+
+    fn get_intermediate_layers_not_chunked(
+        &self,
+        xs: &Tensor,
+        blocks_to_take: &[usize],
+    ) -> Result<Vec<Tensor>> {
+        let mut xs = self.prepare_tokens_with_mask(xs)?;
+        let mut output = Vec::new();
+        for (i, blk) in self.blocks.iter().enumerate() {
+            xs = blk.forward(&xs)?;
+            if blocks_to_take.contains(&i) {
+                output.push(xs.clone());
+            }
+        }
+        if output.len() != blocks_to_take.len() {
+            candle::bail!(
+                "only {} / {} blocks found",
+                output.len(),
+                blocks_to_take.len()
+            );
+        }
+        Ok(output)
+    }
+
+    pub fn get_intermediate_layers(
+        &self,
+        xs: &Tensor,
+        blocks_to_take: &[usize],
+        reshape: bool,
+        return_class_token: bool,
+        norm: bool,
+    ) -> Result<Tensor> {
+        let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
+        let outputs = if norm {
+            outputs
+                .iter()
+                .map(|out| self.norm.forward(out))
+                .collect::<Result<Vec<_>>>()?
+        } else {
+            outputs
+        };
+        let class_tokens = outputs
+            .iter()
+            .map(|out| out.i((.., 0)))
+            .collect::<Result<Vec<_>>>()?;
+        let outputs = outputs
+            .iter()
+            .map(|out| out.i((.., 1..)))
+            .collect::<Result<Vec<_>>>()?;
+
+        let outputs = if reshape {
+            let (b, _c, w, h) = xs.dims4()?;
+            let patch_size = self.patch_embed.patch_size.0;
+            let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
+            outputs
+                .iter()
+                .map(|out| {
+                    out.reshape((b, w / patch_size, h / patch_size, num_channels))?
+                        .transpose(2, 3)?
+                        .transpose(1, 2)
+                })
+                .collect::<Result<Vec<_>>>()?
+        } else {
+            outputs
+        };
+
+        let outputs = if return_class_token {
+            outputs
+                .iter()
+                .zip(class_tokens.iter())
+                .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
+                .collect::<Result<Vec<_>>>()?
+        } else {
+            outputs
+        };
+
+        Tensor::stack(&outputs[..], 0)
+    }
 }
 
 impl Module for DinoVisionTransformer {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 4628a3de..89ae0f8a 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -6,6 +6,7 @@ pub mod chatglm;
 pub mod clip;
 pub mod convmixer;
 pub mod convnext;
+pub mod depth_anything_v2;
 pub mod dinov2;
 pub mod distilbert;
 pub mod efficientnet;
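
For context, the pieces added by this commit compose as follows. This is a minimal usage sketch, not part of the diff: the function name, the weight-file parameters, and the `dinov2::vit_small` loader call are assumptions for illustration; only `DepthAnythingV2Config`, `DepthAnythingV2`, and the 518x518 input size come from the code above.

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

fn estimate_depth(dino_weights: &str, depth_weights: &str, image: &Tensor) -> Result<Tensor> {
    let device = Device::Cpu;

    // DINOv2 ViT-S backbone; assumed to match the 384-dim `vit_small` config.
    let dino_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[dino_weights], DType::F32, &device)? };
    let dinov2 = dinov2::vit_small(dino_vb)?;

    // Depth head from this commit, built on top of the backbone weights.
    let depth_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[depth_weights], DType::F32, &device)? };
    let conf = DepthAnythingV2Config::vit_small();
    let model = DepthAnythingV2::new(&dinov2, &conf, depth_vb)?;

    // `image` is expected as a normalized (batch, 3, 518, 518) tensor; the output is a
    // relative depth map upsampled back to `input_image_size`.
    model.forward(image)
}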