-rw-r--r-- | candle-examples/examples/yolo-v8/model.rs |  9
-rw-r--r-- | candle-nn/src/batch_norm.rs               |  4
-rw-r--r-- | candle-nn/src/conv.rs                     | 21
3 files changed, 27 insertions, 7 deletions
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index bf48fd84..cecd4ce6 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -1,7 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{
-    batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
-};
+use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
 
 #[derive(Clone, Copy, PartialEq, Debug)]
 pub struct Multiples {
@@ -76,7 +74,6 @@ impl Module for Upsample {
 #[derive(Debug)]
 struct ConvBlock {
     conv: Conv2d,
-    bn: BatchNorm,
     span: tracing::Span,
 }
 
@@ -96,11 +93,10 @@ impl ConvBlock {
             groups: 1,
             dilation: 1,
         };
-        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
+        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
         Ok(Self {
             conv,
-            bn,
             span: tracing::span!(tracing::Level::TRACE, "conv-block"),
         })
     }
@@ -110,7 +106,6 @@ impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
         candle_nn::ops::silu(&xs)
     }
 }
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 05904859..8cfc6740 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -109,6 +109,10 @@ impl BatchNorm {
         &self.running_var
     }
 
+    pub fn eps(&self) -> f64 {
+        self.eps
+    }
+
     pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
         self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
     }
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 89e9f42d..7c0bf841 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -1,4 +1,5 @@
 //! Convolution Layers.
+use crate::BatchNorm;
 use candle::{Result, Tensor};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -115,6 +116,26 @@ impl Conv2d {
     pub fn bias(&self) -> Option<&Tensor> {
         self.bias.as_ref()
     }
+
+    pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
+        if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
+            let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
+            let weight = self
+                .weight()
+                .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
+            let bias = match &self.bias {
+                None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
+                Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
+            };
+            Ok(Self {
+                weight,
+                bias: Some(bias),
+                config: self.config,
+            })
+        } else {
+            candle::bail!("batch norm does not have weight_and_bias")
+        }
+    }
 }
 
 impl crate::Module for Conv2d {
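Note on the change: this commit folds ("absorbs") an inference-mode batch norm into the preceding convolution, so ConvBlock::forward only runs the convolution. With scale gamma, shift beta, running mean mu, running variance var, and epsilon eps, batch norm computes gamma * (y - mu) / sqrt(var + eps) + beta on the conv output y = w * x + b. Letting s = gamma / sqrt(var + eps), the same result comes from a single convolution with weight w' = w * s and bias b' = beta + (b - mu) * s, which is exactly what absorb_bn builds (broadcasting s over the output-channel axis of the 4-D weight). The standalone sketch below checks this identity per output channel with plain scalars; fold_bn and the sample numbers are hypothetical and not part of the candle API.

// Scalar sketch of the batch-norm folding identity used by absorb_bn.
// fold_bn and all constants here are illustrative, not from the patch.
fn fold_bn(w: f64, b: f64, gamma: f64, beta: f64, mean: f64, var: f64, eps: f64) -> (f64, f64) {
    let s = gamma / (var + eps).sqrt(); // std_ in the patch above
    (w * s, beta + (b - mean) * s)      // folded weight and bias
}

fn main() {
    let (w, b) = (2.0, 0.5); // conv weight and bias for one output channel
    let (gamma, beta, mean, var, eps) = (1.5, 0.1, 0.2, 4.0, 1e-3);
    let x = 3.0;

    // Convolution followed by batch norm...
    let reference = gamma * ((w * x + b) - mean) / (var + eps).sqrt() + beta;

    // ...equals a single convolution with the folded parameters.
    let (wf, bf) = fold_bn(w, b, gamma, beta, mean, var, eps);
    let folded = wf * x + bf;

    assert!((reference - folded).abs() < 1e-9);
    println!("reference = {reference}, folded = {folded}");
}

The None arm in the patch is the b = 0 case of the same identity (b' = beta - mu * s), which is why the conv2d_no_bias layer in ConvBlock can be folded too; the fused convolution then carries a bias even though the original had none.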