-rw-r--r--  candle-examples/examples/yolo-v8/model.rs |  9
-rw-r--r--  candle-nn/src/batch_norm.rs               |  4
-rw-r--r--  candle-nn/src/conv.rs                     | 21
3 files changed, 27 insertions, 7 deletions
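
Taken together, the three changes fold ("absorb") each batch norm into the convolution in front of it once, at weight-loading time, so inference runs a single fused convolution instead of conv followed by batch norm. The fold is exact in exact arithmetic: writing \gamma, \beta for the affine parameters, \mu, \sigma^2 for the running statistics, and \varepsilon for the stabilizer, applying batch norm to conv(x) = w * x + b gives

    y = \gamma \cdot \frac{(w * x + b) - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
      = w' * x + b'
    \qquad\text{with}\qquad
    w' = w \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},
    \qquad
    b' = \beta + (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}

per output channel (with b = 0 for the bias-free convolutions used here). This is exactly what the new Conv2d::absorb_bn below computes.
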
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index bf48fd84..cecd4ce6 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -1,7 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{
-    batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
-};
+use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
 
 #[derive(Clone, Copy, PartialEq, Debug)]
 pub struct Multiples {
@@ -76,7 +74,6 @@ impl Module for Upsample {
 #[derive(Debug)]
 struct ConvBlock {
     conv: Conv2d,
-    bn: BatchNorm,
     span: tracing::Span,
 }
 
@@ -96,11 +93,10 @@ impl ConvBlock {
             groups: 1,
             dilation: 1,
         };
-        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
+        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
         Ok(Self {
             conv,
-            bn,
             span: tracing::span!(tracing::Level::TRACE, "conv-block"),
         })
     }
@@ -110,7 +106,6 @@ impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
         candle_nn::ops::silu(&xs)
     }
 }
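
Note that the checkpoint layout is untouched: ConvBlock::load still reads the batch-norm weights through vb.pp("bn"); it just folds them into the conv once instead of keeping a bn field around for every forward pass. A minimal standalone sketch of the same fold-at-load pattern (hypothetical helper name and padding choice, not taken from this commit):

use candle::Result;
use candle_nn::{batch_norm, conv2d_no_bias, Conv2d, Conv2dConfig, VarBuilder};

// Load a bias-free conv and immediately fold the batch-norm statistics in,
// mirroring what ConvBlock::load does after this change.
fn load_fused_conv(c1: usize, c2: usize, k: usize, vb: VarBuilder) -> Result<Conv2d> {
    let cfg = Conv2dConfig {
        padding: k / 2,
        ..Default::default()
    };
    let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
    conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)
}
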
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 05904859..8cfc6740 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -109,6 +109,10 @@ impl BatchNorm {
         &self.running_var
     }
 
+    pub fn eps(&self) -> f64 {
+        self.eps
+    }
+
     pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
         self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
     }
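
The new eps getter exists so that code outside this module can rebuild the batch-norm denominator: absorb_bn needs gamma / sqrt(running_var + eps), and eps was previously private. A minimal sketch of that computation in isolation (hypothetical helper, same failure mode as absorb_bn):

use candle::{Result, Tensor};
use candle_nn::BatchNorm;

// Per-channel folding scale gamma / sqrt(var + eps), the quantity that
// Conv2d::absorb_bn computes from a BatchNorm's public accessors.
fn bn_scale(bn: &BatchNorm) -> Result<Tensor> {
    match bn.weight_and_bias() {
        Some((gamma, _beta)) => gamma.div(&(bn.running_var() + bn.eps())?.sqrt()?),
        None => candle::bail!("batch norm does not have weight_and_bias"),
    }
}
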
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 89e9f42d..7c0bf841 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -1,4 +1,5 @@
 //! Convolution Layers.
+use crate::BatchNorm;
 use candle::{Result, Tensor};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -115,6 +116,26 @@ impl Conv2d {
     pub fn bias(&self) -> Option<&Tensor> {
         self.bias.as_ref()
     }
+
+    pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
+        if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
+            let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
+            let weight = self
+                .weight()
+                .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
+            let bias = match &self.bias {
+                None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
+                Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
+            };
+            Ok(Self {
+                weight,
+                bias: Some(bias),
+                config: self.config,
+            })
+        } else {
+            candle::bail!("batch norm does not have weight_and_bias")
+        }
+    }
 }
 
 impl crate::Module for Conv2d {
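
As a sanity check, the folded conv should reproduce conv-then-bn up to float rounding. A standalone sketch, assuming the BatchNorm::new(num_features, running_mean, running_var, weight, bias, eps) and Conv2d::new(weight, bias, config) constructors, and that BatchNorm's Module impl applies the running statistics (as the yolo-v8 code above relied on):

use candle::{Device, Result, Tensor};
use candle_nn::{BatchNorm, Conv2d, Conv2dConfig, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (c_in, c_out, k) = (3usize, 4usize, 3usize);

    // A bias-free conv, as conv2d_no_bias produces in the yolo-v8 example.
    let w = Tensor::randn(0f32, 1.0, (c_out, c_in, k, k), &dev)?;
    let conv = Conv2d::new(w, None, Conv2dConfig::default());

    // A batch norm with arbitrary running statistics and affine parameters.
    let mean = Tensor::randn(0f32, 1.0, c_out, &dev)?;
    let var = Tensor::rand(0.5f32, 1.5f32, c_out, &dev)?; // keep variance positive
    let gamma = Tensor::randn(0f32, 1.0, c_out, &dev)?;
    let beta = Tensor::randn(0f32, 1.0, c_out, &dev)?;
    let bn = BatchNorm::new(c_out, mean, var, gamma, beta, 1e-3)?;

    // Compare conv -> bn against the single fused conv on a random input.
    let xs = Tensor::randn(0f32, 1.0, (1, c_in, 8, 8), &dev)?;
    let y_ref = bn.forward(&conv.forward(&xs)?)?;
    let y_fused = conv.absorb_bn(&bn)?.forward(&xs)?;
    let max_err = (&y_ref - &y_fused)?.abs()?.flatten_all()?.max(0)?;
    println!("max abs err: {max_err}"); // expect ~1e-6 for f32
    Ok(())
}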