Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/convmixer.rs                    6
-rw-r--r--  candle-transformers/src/models/efficientnet.rs                 3
-rw-r--r--  candle-transformers/src/models/resnet.rs                      16
-rw-r--r--  candle-transformers/src/models/segment_anything/tiny_vit.rs    2
-rw-r--r--  candle-transformers/src/models/wuerstchen/paella_vq.rs         2
5 files changed, 14 insertions, 15 deletions
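
The change is the same in every file below: BatchNorm layers that previously went through Module::forward (via Tensor::apply) are now run through ModuleT with an explicit train = false (via Tensor::apply_t), so these inference-only forward passes normalize with the layer's stored running statistics rather than per-batch statistics. A minimal sketch of that distinction follows, assuming the public candle-core / candle-nn API as an external crate sees it; the crate paths, VarMap-backed weights, channel count, and tensor shape are illustrative and not part of this patch.

// Sketch: BatchNorm behaves differently in train and eval mode; the closures
// in this patch always run at inference time, hence apply_t(&bn, false).
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Freshly initialized weights for illustration; real models load a checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let bn = batch_norm(4, 1e-5, vb.pp("bn"))?;
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 8, 8), &dev)?;
    // Inference: normalize with the stored running mean/variance.
    let eval_out = xs.apply_t(&bn, false)?;
    // Training: normalize with the statistics of the current batch.
    let train_out = xs.apply_t(&bn, true)?;
    println!("{:?} {:?}", eval_out.shape(), train_out.shape());
    Ok(())
}

Passing true instead would normalize with the current batch's mean and variance, which is only wanted while training; the model definitions touched here are inference-only, so they hard-code false.
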
diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs
index 76245f37..f5abfa5d 100644
--- a/candle-transformers/src/models/convmixer.rs
+++ b/candle-transformers/src/models/convmixer.rs
@@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
     let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
     let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
     Ok(candle_nn::func(move |xs| {
-        let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
-        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+        let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
+        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
     }))
 }
@@ -64,7 +64,7 @@ fn convmixer(
         .collect::<Result<Vec<_>>>()?;
     let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
     Ok(candle_nn::func(move |xs| {
-        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
         for block in blocks.iter() {
             xs = xs.apply(block)?
         }
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
index ab51c76d..f15c9c79 100644
--- a/candle-transformers/src/models/efficientnet.rs
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -169,8 +169,7 @@ impl ConvNormActivation {
 impl Module for ConvNormActivation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let xs = self.conv2d.forward(xs)?;
-        let xs = self.bn2d.forward(&xs)?;
+        let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
         if self.activation {
             swish(&xs)
         } else {
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
index f2588e01..30029a0b 100644
--- a/candle-transformers/src/models/resnet.rs
+++ b/candle-transformers/src/models/resnet.rs
@@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
     if stride != 1 || c_in != c_out {
         let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
         let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
-        Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+        Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
     } else {
         Ok(Func::new(|xs| Ok(xs.clone())))
     }
@@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?;
+            .apply_t(&bn2, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -94,7 +94,7 @@ fn resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
@@ -149,13 +149,13 @@ fn bottleneck_block(
     Ok(Func::new(move |xs| {
         let ys = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .apply(&conv2)?
-            .apply(&bn2)?
+            .apply_t(&bn2, false)?
             .relu()?
             .apply(&conv3)?
-            .apply(&bn3)?;
+            .apply_t(&bn3, false)?;
         (xs.apply(&downsample)? + ys)?.relu()
     }))
 }
@@ -206,7 +206,7 @@ fn bottleneck_resnet(
     Ok(Func::new(move |xs| {
         let xs = xs
             .apply(&conv1)?
-            .apply(&bn1)?
+            .apply_t(&bn1, false)?
             .relu()?
             .pad_with_same(D::Minus1, 1, 1)?
             .pad_with_same(D::Minus2, 1, 1)?
diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
index cd2936ab..d1700cc5 100644
--- a/candle-transformers/src/models/segment_anything/tiny_vit.rs
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -28,7 +28,7 @@ impl Conv2dBN {
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        xs.apply(&self.c)?.apply(&self.bn)
+        xs.apply(&self.c)?.apply_t(&self.bn, false)
     }
 }
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 4a69cca0..58f795bb 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -185,7 +185,7 @@ impl PaellaVQ {
             xs = xs.apply(&down_block.1)?
         }
         xs.apply(&self.down_blocks_conv)?
-            .apply(&self.down_blocks_bn)
+            .apply_t(&self.down_blocks_bn, false)
     }
 
     pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {