summaryrefslogtreecommitdiff
path: root/candle-transformers/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--candle-transformers/src/models/depth_anything_v2.rs553
-rw-r--r--candle-transformers/src/models/dinov2.rs78
-rw-r--r--candle-transformers/src/models/mod.rs1
3 files changed, 632 insertions, 0 deletions
diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs
new file mode 100644
index 00000000..9eee6d11
--- /dev/null
+++ b/candle-transformers/src/models/depth_anything_v2.rs
@@ -0,0 +1,553 @@
+use candle::D::Minus1;
+use candle::{Module, Result, Tensor};
+use candle_nn::ops::Identity;
+use candle_nn::{
+ batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
+ BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
+};
+
+use crate::models::dinov2::DinoVisionTransformer;
+
+pub struct DepthAnythingV2Config {
+ out_channel_sizes: [usize; 4],
+ in_channel_size: usize, // embed_dim in the Dino model
+ num_features: usize,
+ use_batch_norm: bool,
+ use_class_token: bool,
+ layer_ids_vits: Vec<usize>,
+ input_image_size: usize,
+ target_patch_size: usize,
+}
+
+impl DepthAnythingV2Config {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ out_channel_sizes: [usize; 4],
+ in_channel_size: usize,
+ num_features: usize,
+ use_batch_norm: bool,
+ use_class_token: bool,
+ layer_ids_vits: Vec<usize>,
+ input_image_size: usize,
+ target_patch_size: usize,
+ ) -> Self {
+ Self {
+ out_channel_sizes,
+ in_channel_size,
+ num_features,
+ use_batch_norm,
+ use_class_token,
+ layer_ids_vits,
+ input_image_size,
+ target_patch_size,
+ }
+ }
+
+ pub fn vit_small() -> Self {
+ Self {
+ out_channel_sizes: [48, 96, 192, 384],
+ in_channel_size: 384,
+ num_features: 64,
+ use_batch_norm: false,
+ use_class_token: false,
+ layer_ids_vits: vec![2, 5, 8, 11],
+ input_image_size: 518,
+ target_patch_size: 518 / 14,
+ }
+ }
+
+ pub fn vit_base() -> Self {
+ Self {
+ out_channel_sizes: [96, 192, 384, 768],
+ in_channel_size: 768,
+ num_features: 128,
+ use_batch_norm: false,
+ use_class_token: false,
+ layer_ids_vits: vec![2, 5, 8, 11],
+ input_image_size: 518,
+ target_patch_size: 518 / 14,
+ }
+ }
+
+ pub fn vit_large() -> Self {
+ Self {
+ out_channel_sizes: [256, 512, 1024, 1024],
+ in_channel_size: 1024,
+ num_features: 256,
+ use_batch_norm: false,
+ use_class_token: false,
+ layer_ids_vits: vec![4, 11, 17, 23],
+ input_image_size: 518,
+ target_patch_size: 518 / 14,
+ }
+ }
+
+ pub fn vit_giant() -> Self {
+ Self {
+ out_channel_sizes: [1536, 1536, 1536, 1536],
+ in_channel_size: 1536,
+ num_features: 384,
+ use_batch_norm: false,
+ use_class_token: false,
+ layer_ids_vits: vec![9, 19, 29, 39],
+ input_image_size: 518,
+ target_patch_size: 518 / 14,
+ }
+ }
+}
+
+pub struct ResidualConvUnit {
+ activation: Activation,
+ conv1: Conv2d,
+ conv2: Conv2d,
+ batch_norm1: Option<BatchNorm>,
+ batch_norm2: Option<BatchNorm>,
+}
+
+impl ResidualConvUnit {
+ pub fn new(
+ conf: &DepthAnythingV2Config,
+ activation: Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ const KERNEL_SIZE: usize = 3;
+ let conv_cfg = Conv2dConfig {
+ padding: 1,
+ stride: 1,
+ dilation: 1,
+ groups: 1,
+ };
+ let conv1 = conv2d(
+ conf.num_features,
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("conv1"),
+ )?;
+ let conv2 = conv2d(
+ conf.num_features,
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("conv2"),
+ )?;
+
+ let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
+ true => {
+ let batch_norm_cfg = BatchNormConfig {
+ eps: 1e-05,
+ remove_mean: false,
+ affine: true,
+ momentum: 0.1,
+ };
+ (
+ Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
+ Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
+ )
+ }
+ false => (None, None),
+ };
+
+ Ok(Self {
+ activation,
+ conv1,
+ conv2,
+ batch_norm1,
+ batch_norm2,
+ })
+ }
+}
+
+impl Module for ResidualConvUnit {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let out = self.activation.forward(xs)?;
+ let out = self.conv1.forward(&out)?;
+ let out = if let Some(batch_norm1) = &self.batch_norm1 {
+ batch_norm1.forward_train(&out)?
+ } else {
+ out
+ };
+
+ let out = self.activation.forward(&out)?;
+ let out = self.conv2.forward(&out)?;
+ let out = if let Some(batch_norm2) = &self.batch_norm2 {
+ batch_norm2.forward_train(&out)?
+ } else {
+ out
+ };
+
+ out + xs
+ }
+}
+
+pub struct FeatureFusionBlock {
+ res_conv_unit1: ResidualConvUnit,
+ res_conv_unit2: ResidualConvUnit,
+ output_conv: Conv2d,
+ target_patch_size: usize,
+}
+
+impl FeatureFusionBlock {
+ pub fn new(
+ conf: &DepthAnythingV2Config,
+ target_patch_size: usize,
+ activation: Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ const KERNEL_SIZE: usize = 1;
+ let conv_cfg = Conv2dConfig {
+ padding: 0,
+ stride: 1,
+ dilation: 1,
+ groups: 1,
+ };
+ let output_conv = conv2d(
+ conf.num_features,
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("out_conv"),
+ )?;
+ let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
+ let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
+
+ Ok(Self {
+ res_conv_unit1,
+ res_conv_unit2,
+ output_conv,
+ target_patch_size,
+ })
+ }
+}
+
+impl Module for FeatureFusionBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let out = self.res_conv_unit2.forward(xs)?;
+ let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
+
+ self.output_conv.forward(&out)
+ }
+}
+
+pub struct Scratch {
+ layer1_rn: Conv2d,
+ layer2_rn: Conv2d,
+ layer3_rn: Conv2d,
+ layer4_rn: Conv2d,
+ refine_net1: FeatureFusionBlock,
+ refine_net2: FeatureFusionBlock,
+ refine_net3: FeatureFusionBlock,
+ refine_net4: FeatureFusionBlock,
+ output_conv1: Conv2d,
+ output_conv2: Sequential,
+}
+
+impl Scratch {
+ pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
+ const KERNEL_SIZE: usize = 3;
+ let conv_cfg = Conv2dConfig {
+ padding: 1,
+ stride: 1,
+ dilation: 1,
+ groups: 1,
+ };
+
+ let layer1_rn = conv2d_no_bias(
+ conf.out_channel_sizes[0],
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("layer1_rn"),
+ )?;
+ let layer2_rn = conv2d_no_bias(
+ conf.out_channel_sizes[1],
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("layer2_rn"),
+ )?;
+ let layer3_rn = conv2d_no_bias(
+ conf.out_channel_sizes[2],
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("layer3_rn"),
+ )?;
+ let layer4_rn = conv2d_no_bias(
+ conf.out_channel_sizes[3],
+ conf.num_features,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("layer4_rn"),
+ )?;
+
+ let refine_net1 = FeatureFusionBlock::new(
+ conf,
+ conf.target_patch_size * 8,
+ Activation::Relu,
+ vb.pp("refinenet1"),
+ )?;
+ let refine_net2 = FeatureFusionBlock::new(
+ conf,
+ conf.target_patch_size * 4,
+ Activation::Relu,
+ vb.pp("refinenet2"),
+ )?;
+ let refine_net3 = FeatureFusionBlock::new(
+ conf,
+ conf.target_patch_size * 2,
+ Activation::Relu,
+ vb.pp("refinenet3"),
+ )?;
+ let refine_net4 = FeatureFusionBlock::new(
+ conf,
+ conf.target_patch_size,
+ Activation::Relu,
+ vb.pp("refinenet4"),
+ )?;
+
+ let conv_cfg = Conv2dConfig {
+ padding: 1,
+ stride: 1,
+ dilation: 1,
+ groups: 1,
+ };
+ let output_conv1 = conv2d(
+ conf.num_features,
+ conf.num_features / 2,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("output_conv1"),
+ )?;
+
+ let output_conv2 = seq();
+ const HEAD_FEATURES_2: usize = 32;
+ const OUT_CHANNELS_2: usize = 1;
+ const KERNEL_SIZE_2: usize = 1;
+ let output_conv2 = output_conv2.add(conv2d(
+ conf.num_features / 2,
+ HEAD_FEATURES_2,
+ KERNEL_SIZE,
+ conv_cfg,
+ vb.pp("output_conv2").pp("0"),
+ )?);
+ let output_conv2 = output_conv2
+ .add(Activation::Relu)
+ .add(conv2d(
+ HEAD_FEATURES_2,
+ OUT_CHANNELS_2,
+ KERNEL_SIZE_2,
+ conv_cfg,
+ vb.pp("output_conv2").pp("2"),
+ )?)
+ .add(Activation::Relu);
+
+ Ok(Self {
+ layer1_rn,
+ layer2_rn,
+ layer3_rn,
+ layer4_rn,
+ refine_net1,
+ refine_net2,
+ refine_net3,
+ refine_net4,
+ output_conv1,
+ output_conv2,
+ })
+ }
+}
+
+const NUM_CHANNELS: usize = 4;
+
+pub struct DPTHead<'a> {
+ conf: &'a DepthAnythingV2Config,
+ projections: Vec<Conv2d>,
+ resize_layers: Vec<Box<dyn Module>>,
+ readout_projections: Vec<Sequential>,
+ scratch: Scratch,
+}
+
+impl<'a> DPTHead<'a> {
+ pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
+ let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
+ for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
+ projections.push(conv2d(
+ conf.in_channel_size,
+ *out_channel_size,
+ 1,
+ Default::default(),
+ vb.pp("projects").pp(conv_index.to_string()),
+ )?);
+ }
+
+ let resize_layers: Vec<Box<dyn Module>> = vec![
+ Box::new(conv_transpose2d(
+ conf.out_channel_sizes[0],
+ conf.out_channel_sizes[0],
+ 4,
+ ConvTranspose2dConfig {
+ padding: 0,
+ stride: 4,
+ dilation: 1,
+ output_padding: 0,
+ },
+ vb.pp("resize_layers").pp("0"),
+ )?),
+ Box::new(conv_transpose2d(
+ conf.out_channel_sizes[1],
+ conf.out_channel_sizes[1],
+ 2,
+ ConvTranspose2dConfig {
+ padding: 0,
+ stride: 2,
+ dilation: 1,
+ output_padding: 0,
+ },
+ vb.pp("resize_layers").pp("1"),
+ )?),
+ Box::new(Identity::new()),
+ Box::new(conv2d(
+ conf.out_channel_sizes[3],
+ conf.out_channel_sizes[3],
+ 3,
+ Conv2dConfig {
+ padding: 1,
+ stride: 2,
+ dilation: 1,
+ groups: 1,
+ },
+ vb.pp("resize_layers").pp("3"),
+ )?),
+ ];
+
+ let readout_projections = if conf.use_class_token {
+ let rop = Vec::with_capacity(NUM_CHANNELS);
+ for rop_index in 0..NUM_CHANNELS {
+ seq()
+ .add(linear(
+ 2 * conf.in_channel_size,
+ conf.in_channel_size,
+ vb.pp("readout_projects").pp(rop_index.to_string()),
+ )?)
+ .add(Activation::Gelu);
+ }
+ rop
+ } else {
+ vec![]
+ };
+
+ let scratch = Scratch::new(conf, vb.pp("scratch"))?;
+
+ Ok(Self {
+ conf,
+ projections,
+ resize_layers,
+ readout_projections,
+ scratch,
+ })
+ }
+}
+
+impl Module for DPTHead<'_> {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
+ for i in 0..NUM_CHANNELS {
+ let x = if self.conf.use_class_token {
+ let x = xs.get(i)?.get(0)?;
+ let class_token = xs.get(i)?.get(1)?;
+ let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
+ let to_cat = [x, readout];
+ let cat = Tensor::cat(&to_cat, Minus1)?;
+ self.readout_projections[i].forward(&cat)?
+ } else {
+ xs.get(i)?
+ };
+ let x_dims = x.dims();
+
+ let x = x.permute((0, 2, 1))?.reshape((
+ x_dims[0],
+ x_dims[x_dims.len() - 1],
+ self.conf.target_patch_size,
+ self.conf.target_patch_size,
+ ))?;
+ let x = self.projections[i].forward(&x)?;
+
+ let x = self.resize_layers[i].forward(&x)?;
+ out.push(x);
+ }
+
+ let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
+ let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
+ let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
+ let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
+
+ let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
+
+ let res3_out = self
+ .scratch
+ .refine_net3
+ .res_conv_unit1
+ .forward(&layer_3_rn)?;
+ let res3_out = path4.add(&res3_out)?;
+ let path3 = self.scratch.refine_net3.forward(&res3_out)?;
+
+ let res2_out = self
+ .scratch
+ .refine_net2
+ .res_conv_unit1
+ .forward(&layer_2_rn)?;
+ let res2_out = path3.add(&res2_out)?;
+ let path2 = self.scratch.refine_net2.forward(&res2_out)?;
+
+ let res1_out = self
+ .scratch
+ .refine_net1
+ .res_conv_unit1
+ .forward(&layer_1_rn)?;
+ let res1_out = path2.add(&res1_out)?;
+ let path1 = self.scratch.refine_net1.forward(&res1_out)?;
+
+ let out = self.scratch.output_conv1.forward(&path1)?;
+
+ let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
+
+ self.scratch.output_conv2.forward(&out)
+ }
+}
+
+pub struct DepthAnythingV2<'a> {
+ pretrained: &'a DinoVisionTransformer,
+ depth_head: DPTHead<'a>,
+ conf: &'a DepthAnythingV2Config,
+}
+
+impl<'a> DepthAnythingV2<'a> {
+ pub fn new(
+ pretrained: &'a DinoVisionTransformer,
+ conf: &'a DepthAnythingV2Config,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
+
+ Ok(Self {
+ pretrained,
+ depth_head,
+ conf,
+ })
+ }
+}
+
+impl<'a> Module for DepthAnythingV2<'a> {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let features = self.pretrained.get_intermediate_layers(
+ xs,
+ &self.conf.layer_ids_vits,
+ false,
+ false,
+ true,
+ )?;
+ let depth = self.depth_head.forward(&features)?;
+
+ depth.relu()
+ }
+}
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 757aa88a..00e501ce 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -258,6 +258,84 @@ impl DinoVisionTransformer {
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
+
+ fn get_intermediate_layers_not_chunked(
+ &self,
+ xs: &Tensor,
+ blocks_to_take: &[usize],
+ ) -> Result<Vec<Tensor>> {
+ let mut xs = self.prepare_tokens_with_mask(xs)?;
+ let mut output = Vec::new();
+ for (i, blk) in self.blocks.iter().enumerate() {
+ xs = blk.forward(&xs)?;
+ if blocks_to_take.contains(&i) {
+ output.push(xs.clone());
+ }
+ }
+ if output.len() != blocks_to_take.len() {
+ candle::bail!(
+ "only {} / {} blocks found",
+ output.len(),
+ blocks_to_take.len()
+ );
+ }
+ Ok(output)
+ }
+
+ pub fn get_intermediate_layers(
+ &self,
+ xs: &Tensor,
+ blocks_to_take: &[usize],
+ reshape: bool,
+ return_class_token: bool,
+ norm: bool,
+ ) -> Result<Tensor> {
+ let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
+ let outputs = if norm {
+ outputs
+ .iter()
+ .map(|out| self.norm.forward(out))
+ .collect::<Result<Vec<_>>>()?
+ } else {
+ outputs
+ };
+ let class_tokens = outputs
+ .iter()
+ .map(|out| out.i((.., 0)))
+ .collect::<Result<Vec<_>>>()?;
+ let outputs = outputs
+ .iter()
+ .map(|out| out.i((.., 1..)))
+ .collect::<Result<Vec<_>>>()?;
+
+ let outputs = if reshape {
+ let (b, _c, w, h) = xs.dims4()?;
+ let patch_size = self.patch_embed.patch_size.0;
+ let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
+ outputs
+ .iter()
+ .map(|out| {
+ out.reshape((b, w / patch_size, h / patch_size, num_channels))?
+ .transpose(2, 3)?
+ .transpose(1, 2)
+ })
+ .collect::<Result<Vec<_>>>()?
+ } else {
+ outputs
+ };
+
+ let outputs = if return_class_token {
+ outputs
+ .iter()
+ .zip(class_tokens.iter())
+ .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
+ .collect::<Result<Vec<_>>>()?
+ } else {
+ outputs
+ };
+
+ Tensor::stack(&outputs[..], 0)
+ }
}
impl Module for DinoVisionTransformer {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 4628a3de..89ae0f8a 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -6,6 +6,7 @@ pub mod chatglm;
pub mod clip;
pub mod convmixer;
pub mod convnext;
+pub mod depth_anything_v2;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;