summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/mnist-training/main.rs16
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs2
-rw-r--r--candle-examples/examples/yolo-v8/model.rs6
3 files changed, 13 insertions, 11 deletions
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 5bbce31b..a90904c4 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -83,13 +83,15 @@ impl Model for ConvNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
- let xs = xs.reshape((b_sz, 1, 28, 28))?;
- let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
- let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
- let xs = xs.flatten_from(1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = xs.relu()?;
- self.fc2.forward(&xs)
+ xs.reshape((b_sz, 1, 28, 28))?
+ .apply(&self.conv1)?
+ .max_pool2d(2)?
+ .apply(&self.conv2)?
+ .max_pool2d(2)?
+ .flatten_from(1)?
+ .apply(&self.fc1)?
+ .relu()?
+ .apply(&self.fc2)
}
}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 1db65222..26a1035b 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -47,7 +47,7 @@ impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
- None => xs.avg_pool2d((2, 2), (2, 2)),
+ None => xs.avg_pool2d(2),
Some(conv) => {
if self.padding == 0 {
let xs = xs
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index d7fe5c12..b834f967 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -198,15 +198,15 @@ impl Module for Sppf {
let xs2 = xs
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs3 = xs2
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs4 = xs3
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
}
}