diff options
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r-- | candle-nn/src/conv.rs | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 89e9f42d..7c0bf841 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -1,4 +1,5 @@ //! Convolution Layers. +use crate::BatchNorm; use candle::{Result, Tensor}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -115,6 +116,26 @@ impl Conv2d { pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } + + pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> { + if let Some((w_bn, b_bn)) = bn.weight_and_bias() { + let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?; + let weight = self + .weight() + .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?; + let bias = match &self.bias { + None => b_bn.sub(&(std_.mul(bn.running_mean())?))?, + Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?, + }; + Ok(Self { + weight, + bias: Some(bias), + config: self.config, + }) + } else { + candle::bail!("batch norm does not have weight_and_bias") + } + } } impl crate::Module for Conv2d { |