-rw-r--r--  candle-core/src/backprop.rs                                    2
-rw-r--r--  candle-core/src/lib.rs                                        33
-rw-r--r--  candle-core/src/tensor.rs                                     30
-rw-r--r--  candle-core/tests/pool_tests.rs                               14
-rw-r--r--  candle-examples/examples/mnist-training/main.rs               16
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs    2
-rw-r--r--  candle-examples/examples/yolo-v8/model.rs                      6
-rw-r--r--  candle-nn/src/lib.rs                                          19
-rw-r--r--  candle-wasm-examples/yolo/src/model.rs                         6
9 files changed, 86 insertions(+), 42 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f4f90373..c6d55e61 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -256,7 +256,7 @@ impl Tensor {
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
- let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+ let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index fa85f6e0..a0347416 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -91,3 +91,36 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+
+pub trait ToUsize2 {
+ fn to_usize2(self) -> (usize, usize);
+}
+
+impl ToUsize2 for usize {
+ fn to_usize2(self) -> (usize, usize) {
+ (self, self)
+ }
+}
+
+impl ToUsize2 for (usize, usize) {
+ fn to_usize2(self) -> (usize, usize) {
+ self
+ }
+}
+
+// A simple trait defining a module with forward method using a single argument.
+pub trait Module: std::fmt::Debug {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+
+ /// Change the module to use training mode vs eval mode.
+ ///
+ /// The default implementation does nothing as this is only used for a couple modules such as
+ /// dropout or batch-normalization.
+ fn set_training(&mut self, _training: bool) {}
+}
+
+impl Module for quantized::QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward(xs)
+ }
+}
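
The `ToUsize2` trait added above is what lets the pooling methods accept either a bare `usize` (interpreted as a square kernel) or an explicit `(usize, usize)` pair. A minimal standalone sketch of just that conversion, duplicating the trait here so it runs outside the crate:

```rust
// Sketch of the ToUsize2 conversion introduced above: a bare usize becomes
// a square (n, n) pair, while an explicit tuple passes through untouched.
pub trait ToUsize2 {
    fn to_usize2(self) -> (usize, usize);
}

impl ToUsize2 for usize {
    fn to_usize2(self) -> (usize, usize) {
        (self, self)
    }
}

impl ToUsize2 for (usize, usize) {
    fn to_usize2(self) -> (usize, usize) {
        self
    }
}

fn main() {
    assert_eq!(2usize.to_usize2(), (2, 2)); // square-kernel shorthand
    assert_eq!((2usize, 3usize).to_usize2(), (2, 3)); // rectangular kernel
}
```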
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 75b3743d..f834e040 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -797,7 +797,18 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
- pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+ let sz = sz.to_usize2();
+ self.avg_pool2d_with_stride(sz, sz)
+ }
+
+ pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
+ &self,
+ kernel_size: T,
+ stride: T,
+ ) -> Result<Self> {
+ let kernel_size = kernel_size.to_usize2();
+ let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -813,7 +824,18 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
- pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+ let sz = sz.to_usize2();
+ self.max_pool2d_with_stride(sz, sz)
+ }
+
+ pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
+ &self,
+ kernel_size: T,
+ stride: T,
+ ) -> Result<Self> {
+ let kernel_size = kernel_size.to_usize2();
+ let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -1855,6 +1877,10 @@ impl Tensor {
}
}
+ pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
+ m.forward(self)
+ }
+
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
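
With these changes, `avg_pool2d(k)` and `max_pool2d(k)` are shorthand for the `_with_stride` variants with the stride equal to the kernel size, and `Tensor::apply` lets a module be called in method position. A hedged usage sketch (assuming the crate is imported as `candle_core`; the output shapes follow the `(size - kernel) / stride + 1` formula in the code above):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn pooling_demo() -> Result<()> {
    let t = Tensor::zeros((1, 1, 4, 4), DType::F32, &Device::Cpu)?;
    // Kernel 2, stride defaults to the kernel size: (4 - 2) / 2 + 1 = 2.
    let a = t.avg_pool2d(2)?;
    assert_eq!(a.dims4()?, (1, 1, 2, 2));
    // Explicit stride 1: overlapping windows, (4 - 2) / 1 + 1 = 3.
    let b = t.max_pool2d_with_stride(2, 1)?;
    assert_eq!(b.dims4()?, (1, 1, 3, 3));
    Ok(())
}
```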
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index b8c007b8..c6db194d 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -6,14 +6,14 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
- let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?;
- let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]);
Ok(())
}
@@ -24,11 +24,11 @@ fn max_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
- let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
let t = t.reshape((1, 1, 2, 8))?;
- let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2.0, 3.0, 5.0, 1.0]]);
Ok(())
}
@@ -53,7 +53,7 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
dev,
)?
.reshape((1, 2, 4, 4))?;
- let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
+ let pool = t.avg_pool2d(2)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[
@@ -61,14 +61,14 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
[[0.1835, -0.1606], [0.6249, 0.3217]]
]
);
- let pool = t.avg_pool2d((3, 3), (3, 3))?.squeeze(0)?;
+ let pool = t.avg_pool2d(3)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec3_round(&pool, 4)?,
[[[0.085]], [[0.0078]]]
);
let t = t.reshape((1, 1, 4, 8))?;
- let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?;
assert_eq!(
test_utils::to_vec2_round(&pool, 4)?,
[
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 5bbce31b..a90904c4 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -83,13 +83,15 @@ impl Model for ConvNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
- let xs = xs.reshape((b_sz, 1, 28, 28))?;
- let xs = self.conv1.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
- let xs = self.conv2.forward(&xs)?.max_pool2d((2, 2), (2, 2))?;
- let xs = xs.flatten_from(1)?;
- let xs = self.fc1.forward(&xs)?;
- let xs = xs.relu()?;
- self.fc2.forward(&xs)
+ xs.reshape((b_sz, 1, 28, 28))?
+ .apply(&self.conv1)?
+ .max_pool2d(2)?
+ .apply(&self.conv2)?
+ .max_pool2d(2)?
+ .flatten_from(1)?
+ .apply(&self.fc1)?
+ .relu()?
+ .apply(&self.fc2)
}
}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 1db65222..26a1035b 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -47,7 +47,7 @@ impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
- None => xs.avg_pool2d((2, 2), (2, 2)),
+ None => xs.avg_pool2d(2),
Some(conv) => {
if self.padding == 0 {
let xs = xs
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index d7fe5c12..b834f967 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -198,15 +198,15 @@ impl Module for Sppf {
let xs2 = xs
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs3 = xs2
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs4 = xs3
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
}
}
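
The SPPF change is where `max_pool2d_with_stride` earns its keep: padding both spatial dimensions by `k / 2` and then max-pooling with stride 1 keeps the feature map size unchanged for odd `k`, so the four branches can be concatenated along the channel axis. A quick arithmetic check of that invariant (plain integer math, no candle dependency):

```rust
// For odd k, 2 * (k / 2) == k - 1, so padding by k / 2 on each side and
// max-pooling with kernel k at stride 1 preserves the spatial size:
// out = (w + 2 * (k / 2) - k) / 1 + 1 == w.
fn pooled_size(w: usize, k: usize) -> usize {
    (w + 2 * (k / 2) - k) + 1
}

fn main() {
    for k in [3, 5, 7] {
        assert_eq!(pooled_size(13, k), 13); // size preserved for odd kernels
    }
}
```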
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 34e2dbed..2e2c2545 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,5 +1,3 @@
-use candle::{Result, Tensor};
-
pub mod activation;
pub mod batch_norm;
pub mod conv;
@@ -28,19 +26,4 @@ pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
-// A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
- fn forward(&self, xs: &Tensor) -> Result<Tensor>;
-
- /// Change the module to use training mode vs eval mode.
- ///
- /// The default implementation does nothing as this is only used for a couple modules such as
- /// dropout or batch-normalization.
- fn set_training(&mut self, _training: bool) {}
-}
-
-impl Module for candle::quantized::QMatMul {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- self.forward(xs)
- }
-}
+pub use candle::Module;
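
Since `Module` now lives in candle-core and candle-nn only re-exports it, both `use candle_nn::Module;` and `use candle_core::Module;` name the same trait, and custom layers pair naturally with the new `Tensor::apply`. A minimal sketch of a user-defined module (the `Scale` layer is hypothetical, used only for illustration):

```rust
use candle_core::{Module, Result, Tensor};

// Hypothetical layer for illustration: multiplies its input by a constant.
#[derive(Debug)]
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.0) // y = mul * x + add
    }
}

// With the Tensor::apply helper added above, call sites can chain:
// let ys = xs.apply(&Scale(2.0))?;
```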
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index e0fa7ac4..d49cf55f 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -202,15 +202,15 @@ impl Module for Sppf {
let xs2 = xs
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs3 = xs2
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
let xs4 = xs3
.pad_with_zeros(2, self.k / 2, self.k / 2)?
.pad_with_zeros(3, self.k / 2, self.k / 2)?
- .max_pool2d((self.k, self.k), (1, 1))?;
+ .max_pool2d_with_stride(self.k, 1)?;
self.cv2.forward(&Tensor::cat(&[&xs, &xs2, &xs3, &xs4], 1)?)
}
}