diff options
author | jamjamjon <51357717+jamjamjon@users.noreply.github.com> | 2023-10-27 22:56:50 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-27 15:56:50 +0100 |
commit | b3181455d5bbebdcc358a48fd4d1e5ed38d78198 (patch) | |
tree | c1e59be9d5c0a909c32c25053bbe3e926554c697 /candle-examples/examples/yolo-v8/model.rs | |
parent | e2826e70b3725c53656f1ff76753472b29e1c5f7 (diff) | |
download | candle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.tar.gz candle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.tar.bz2 candle-b3181455d5bbebdcc358a48fd4d1e5ed38d78198.zip |
Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d
* no unwrap
* run rustfmp and clippy
Diffstat (limited to 'candle-examples/examples/yolo-v8/model.rs')
-rw-r--r-- | candle-examples/examples/yolo-v8/model.rs | 9 |
1 files changed, 2 insertions, 7 deletions
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index bf48fd84..cecd4ce6 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -1,7 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{ - batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder, -}; +use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder}; #[derive(Clone, Copy, PartialEq, Debug)] pub struct Multiples { @@ -76,7 +74,6 @@ impl Module for Upsample { #[derive(Debug)] struct ConvBlock { conv: Conv2d, - bn: BatchNorm, span: tracing::Span, } @@ -96,11 +93,10 @@ impl ConvBlock { groups: 1, dilation: 1, }; - let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; + let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?; Ok(Self { conv, - bn, span: tracing::span!(tracing::Level::TRACE, "conv-block"), }) } @@ -110,7 +106,6 @@ impl Module for ConvBlock { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let xs = self.conv.forward(xs)?; - let xs = self.bn.forward(&xs)?; candle_nn::ops::silu(&xs) } } |