diff options
author | Edgar Riba <edgar.riba@gmail.com> | 2024-12-21 12:06:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-21 12:06:03 +0100 |
commit | 5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (patch) | |
tree | b0f90e5c82f676a8935afbf2db2f468e43d298ad | |
parent | 67cab7d6b8279f953b0a8cc5012b135b9743cdc8 (diff) | |
download | candle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.tar.gz candle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.tar.bz2 candle-5c2f893e5aa21c9f7c82a00407edb6d76db1d06c.zip |
make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable
* Fix clippy lints.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r-- | candle-examples/examples/depth_anything_v2/main.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/depth_anything_v2.rs | 44 |
2 files changed, 27 insertions, 23 deletions
diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs index ef337eba..2608b40d 100644 --- a/candle-examples/examples/depth_anything_v2/main.rs +++ b/candle-examples/examples/depth_anything_v2/main.rs @@ -6,10 +6,8 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use std::ffi::OsString; -use std::path::PathBuf; - use clap::Parser; +use std::{ffi::OsString, path::PathBuf, sync::Arc}; use candle::DType::{F32, U8}; use candle::{DType, Device, Module, Result, Tensor}; @@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> { }; let config = DepthAnythingV2Config::vit_small(); - let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?; + let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?; let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?; diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs index 8eddbf2a..3b6bd1a5 100644 --- a/candle-transformers/src/models/depth_anything_v2.rs +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -4,6 +4,8 @@ //! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything) //! +use std::sync::Arc; + use candle::D::Minus1; use candle::{Module, Result, Tensor}; use candle_nn::ops::Identity; @@ -365,16 +367,18 @@ impl Scratch { const NUM_CHANNELS: usize = 4; -pub struct DPTHead<'a> { - conf: &'a DepthAnythingV2Config, +pub struct DPTHead { projections: Vec<Conv2d>, resize_layers: Vec<Box<dyn Module>>, readout_projections: Vec<Sequential>, scratch: Scratch, + use_class_token: bool, + input_image_size: usize, + target_patch_size: usize, } -impl<'a> DPTHead<'a> { - pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> { +impl DPTHead { + pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> { let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len()); for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() { projections.push(conv2d( @@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> { let scratch = Scratch::new(conf, vb.pp("scratch"))?; Ok(Self { - conf, projections, resize_layers, readout_projections, scratch, + use_class_token: conf.use_class_token, + input_image_size: conf.input_image_size, + target_patch_size: conf.target_patch_size, }) } } -impl Module for DPTHead<'_> { +impl Module for DPTHead { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS); for i in 0..NUM_CHANNELS { - let x = if self.conf.use_class_token { + let x = if self.use_class_token { let x = xs.get(i)?.get(0)?; let class_token = xs.get(i)?.get(1)?; let readout = class_token.unsqueeze(1)?.expand(x.shape())?; @@ -473,8 +479,8 @@ impl Module for DPTHead<'_> { let x = x.permute((0, 2, 1))?.reshape(( x_dims[0], x_dims[x_dims.len() - 1], - self.conf.target_patch_size, - self.conf.target_patch_size, + self.target_patch_size, + self.target_patch_size, ))?; let x = self.projections[i].forward(&x)?; @@ -515,25 +521,25 @@ impl Module for DPTHead<'_> { let out = self.scratch.output_conv1.forward(&path1)?; - let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?; + let out = out.interpolate2d(self.input_image_size, self.input_image_size)?; self.scratch.output_conv2.forward(&out) } } -pub struct DepthAnythingV2<'a> { - pretrained: &'a DinoVisionTransformer, - depth_head: DPTHead<'a>, - conf: &'a DepthAnythingV2Config, +pub struct DepthAnythingV2 { + pretrained: Arc<DinoVisionTransformer>, + depth_head: DPTHead, + conf: DepthAnythingV2Config, } -impl<'a> DepthAnythingV2<'a> { +impl DepthAnythingV2 { pub fn new( - pretrained: &'a DinoVisionTransformer, - conf: &'a DepthAnythingV2Config, + pretrained: Arc<DinoVisionTransformer>, + conf: DepthAnythingV2Config, vb: VarBuilder, ) -> Result<Self> { - let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?; + let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?; Ok(Self { pretrained, @@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> { } } -impl Module for DepthAnythingV2<'_> { +impl Module for DepthAnythingV2 { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let features = self.pretrained.get_intermediate_layers( xs, |