diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/mnist-training/main.rs | 16 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v8/model.rs | 6 |
3 files changed, 13 insertions, 11 deletions
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index 5bbce31b..a90904c4 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -83,13 +83,15 @@ impl Model for ConvNet { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let (b_sz, _img_dim) = xs.dims2()?; - let xs = xs.reshape((b_sz, 1, 28, 28))?; - let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?; - let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?; - let xs = xs.flatten_from(1)?; - let xs = self.fc1.forward(&xs)?; - let xs = xs.relu()?; - self.fc2.forward(&xs) + xs.reshape((b_sz, 1, 28, 28))? + .apply(&self.conv1)? + .max_pool2d(2)? + .apply(&self.conv2)? + .max_pool2d(2)? + .flatten_from(1)? + .apply(&self.fc1)? + .relu()? + .apply(&self.fc2) } } diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 1db65222..26a1035b 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -47,7 +47,7 @@ impl Downsample2D { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); match &self.conv { - None => xs.avg_pool2d((2, 2), (2, 2)), + None => xs.avg_pool2d(2), Some(conv) => { if self.padding == 0 { let xs = xs diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs index d7fe5c12..b834f967 100644 --- a/candle-examples/examples/yolo-v8/model.rs +++ b/candle-examples/examples/yolo-v8/model.rs @@ -198,15 +198,15 @@ impl Module for Sppf { let xs2 = xs .pad_with_zeros(2, self.k / 2, self.k / 2)? .pad_with_zeros(3, self.k / 2, self.k / 2)? - .max_pool2d((self.k, self.k), (1, 1))?; + .max_pool2d_with_stride(self.k, 1)?; let xs3 = xs2 .pad_with_zeros(2, self.k / 2, self.k / 2)? .pad_with_zeros(3, self.k / 2, self.k / 2)? - .max_pool2d((self.k, self.k), (1, 1))?; + .max_pool2d_with_stride(self.k, 1)?; let xs4 = xs3 .pad_with_zeros(2, self.k / 2, self.k / 2)? .pad_with_zeros(3, self.k / 2, self.k / 2)? - .max_pool2d((self.k, self.k), (1, 1))?; + .max_pool2d_with_stride(self.k, 1)?; self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?) } } |