diff options
Diffstat (limited to 'candle-transformers/src/models/resnet.rs')
-rw-r--r-- | candle-transformers/src/models/resnet.rs | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs index f2588e01..30029a0b 100644 --- a/candle-transformers/src/models/resnet.rs +++ b/candle-transformers/src/models/resnet.rs @@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul if stride != 1 || c_in != c_out { let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?; let bn = batch_norm(c_out, 1e-5, vb.pp(1))?; - Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn))) + Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false))) } else { Ok(Func::new(|xs| Ok(xs.clone()))) } @@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu Ok(Func::new(move |xs| { let ys = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .apply(&conv2)? - .apply(&bn2)?; + .apply_t(&bn2, false)?; (xs.apply(&downsample)? + ys)?.relu() })) } @@ -94,7 +94,7 @@ fn resnet( Ok(Func::new(move |xs| { let xs = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .pad_with_same(D::Minus1, 1, 1)? .pad_with_same(D::Minus2, 1, 1)? @@ -149,13 +149,13 @@ fn bottleneck_block( Ok(Func::new(move |xs| { let ys = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .apply(&conv2)? - .apply(&bn2)? + .apply_t(&bn2, false)? .relu()? .apply(&conv3)? - .apply(&bn3)?; + .apply_t(&bn3, false)?; (xs.apply(&downsample)? + ys)?.relu() })) } @@ -206,7 +206,7 @@ fn bottleneck_resnet( Ok(Func::new(move |xs| { let xs = xs .apply(&conv1)? - .apply(&bn1)? + .apply_t(&bn1, false)? .relu()? .pad_with_same(D::Minus1, 1, 1)? .pad_with_same(D::Minus2, 1, 1)? |