summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEdgar Riba <edgar.riba@gmail.com>2024-12-21 12:06:03 +0100
committerGitHub <noreply@github.com>2024-12-21 12:06:03 +0100
commit5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (patch)
treeb0f90e5c82f676a8935afbf2db2f468e43d298ad
parent67cab7d6b8279f953b0a8cc5012b135b9743cdc8 (diff)
downloadcandle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.tar.gz
candle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.tar.bz2
candle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.zip
make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable * Fix clippy lints. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--candle-examples/examples/depth_anything_v2/main.rs6
-rw-r--r--candle-transformers/src/models/depth_anything_v2.rs44
2 files changed, 27 insertions, 23 deletions
diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs
index ef337eba..2608b40d 100644
--- a/candle-examples/examples/depth_anything_v2/main.rs
+++ b/candle-examples/examples/depth_anything_v2/main.rs
@@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-use std::ffi::OsString;
-use std::path::PathBuf;
-
use clap::Parser;
+use std::{ffi::OsString, path::PathBuf, sync::Arc};
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
@@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};
let config = DepthAnythingV2Config::vit_small();
- let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
+ let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs
index 8eddbf2a..3b6bd1a5 100644
--- a/candle-transformers/src/models/depth_anything_v2.rs
+++ b/candle-transformers/src/models/depth_anything_v2.rs
@@ -4,6 +4,8 @@
//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
//!
+use std::sync::Arc;
+
use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
@@ -365,16 +367,18 @@ impl Scratch {
const NUM_CHANNELS: usize = 4;
-pub struct DPTHead<'a> {
- conf: &'a DepthAnythingV2Config,
+pub struct DPTHead {
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
+ use_class_token: bool,
+ input_image_size: usize,
+ target_patch_size: usize,
}
-impl<'a> DPTHead<'a> {
- pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
+impl DPTHead {
+ pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
@@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> {
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
Ok(Self {
- conf,
projections,
resize_layers,
readout_projections,
scratch,
+ use_class_token: conf.use_class_token,
+ input_image_size: conf.input_image_size,
+ target_patch_size: conf.target_patch_size,
})
}
}
-impl Module for DPTHead<'_> {
+impl Module for DPTHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
- let x = if self.conf.use_class_token {
+ let x = if self.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
@@ -473,8 +479,8 @@ impl Module for DPTHead<'_> {
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
- self.conf.target_patch_size,
- self.conf.target_patch_size,
+ self.target_patch_size,
+ self.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;
@@ -515,25 +521,25 @@ impl Module for DPTHead<'_> {
let out = self.scratch.output_conv1.forward(&path1)?;
- let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
+ let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;
self.scratch.output_conv2.forward(&out)
}
}
-pub struct DepthAnythingV2<'a> {
- pretrained: &'a DinoVisionTransformer,
- depth_head: DPTHead<'a>,
- conf: &'a DepthAnythingV2Config,
+pub struct DepthAnythingV2 {
+ pretrained: Arc<DinoVisionTransformer>,
+ depth_head: DPTHead,
+ conf: DepthAnythingV2Config,
}
-impl<'a> DepthAnythingV2<'a> {
+impl DepthAnythingV2 {
pub fn new(
- pretrained: &'a DinoVisionTransformer,
- conf: &'a DepthAnythingV2Config,
+ pretrained: Arc<DinoVisionTransformer>,
+ conf: DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
- let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
+ let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;
Ok(Self {
pretrained,
@@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> {
}
}
-impl Module for DepthAnythingV2<'_> {
+impl Module for DepthAnythingV2 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,