summaryrefslogtreecommitdiff
path: root/candle-examples/examples/yolo-v8/model.rs
diff options
context:
space:
mode:
authorjamjamjon <51357717+jamjamjon@users.noreply.github.com>2023-10-27 22:56:50 +0800
committerGitHub <noreply@github.com>2023-10-27 15:56:50 +0100
commitb3181455d5bbebdcc358a48fd4d1e5ed38d78198 (patch)
treec1e59be9d5c0a909c32c25053bbe3e926554c697 /candle-examples/examples/yolo-v8/model.rs
parente2826e70b3725c53656f1ff76753472b29e1c5f7 (diff)
downloadcandle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.tar.gz
candle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.tar.bz2
candle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.zip
Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d * no unwrap * run rustfmp and clippy
Diffstat (limited to 'candle-examples/examples/yolo-v8/model.rs')
-rw-r--r--candle-examples/examples/yolo-v8/model.rs9
1 files changed, 2 insertions, 7 deletions
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index bf48fd84..cecd4ce6 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -1,7 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{
- batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
-};
+use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Multiples {
@@ -76,7 +74,6 @@ impl Module for Upsample {
#[derive(Debug)]
struct ConvBlock {
conv: Conv2d,
- bn: BatchNorm,
span: tracing::Span,
}
@@ -96,11 +93,10 @@ impl ConvBlock {
groups: 1,
dilation: 1,
};
- let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
+ let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
Ok(Self {
conv,
- bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
@@ -110,7 +106,6 @@ impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
- let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
}
}