diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/yolo-v8/model.rs | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index bf48fd84..cecd4ce6 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -1,7 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{ - batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder, -}; +use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder}; #[derive(Clone, Copy, PartialEq, Debug)] pub struct Multiples { @@ -76,7 +74,6 @@ impl Module for Upsample { #[derive(Debug)] struct ConvBlock { conv: Conv2d, - bn: BatchNorm, span: tracing::Span, } @@ -96,11 +93,10 @@ impl ConvBlock { groups: 1, dilation: 1, }; - let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; + let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?; Ok(Self { conv, - bn, span: tracing::span!(tracing::Level::TRACE, "conv-block"), }) } @@ -110,7 +106,6 @@ impl Module for ConvBlock { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let xs = self.conv.forward(xs)?; - let xs = self.bn.forward(&xs)?; candle_nn::ops::silu(&xs) } } |