diff options
Diffstat (limited to 'candle-transformers/src/models/convmixer.rs')
-rw-r--r-- | candle-transformers/src/models/convmixer.rs | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs index 76245f37..f5abfa5d 100644 --- a/candle-transformers/src/models/convmixer.rs +++ b/candle-transformers/src/models/convmixer.rs @@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module> let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?; let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?; Ok(candle_nn::func(move |xs| { - let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?; - (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2) + let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?; + (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false) })) } @@ -64,7 +64,7 @@ fn convmixer( .collect::<Result<Vec<_>>>()?; let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?; Ok(candle_nn::func(move |xs| { - let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?; + let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?; for block in blocks.iter() { xs = xs.apply(block)? } |