summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs30
1 files changed, 28 insertions, 2 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 75b3743d..f834e040 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -797,7 +797,18 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
- pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+ let sz = sz.to_usize2();
+ self.avg_pool2d_with_stride(sz, sz)
+ }
+
+ pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
+ &self,
+ kernel_size: T,
+ stride: T,
+ ) -> Result<Self> {
+ let kernel_size = kernel_size.to_usize2();
+ let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -813,7 +824,18 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
- pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+ let sz = sz.to_usize2();
+ self.max_pool2d_with_stride(sz, sz)
+ }
+
+ pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
+ &self,
+ kernel_size: T,
+ stride: T,
+ ) -> Result<Self> {
+ let kernel_size = kernel_size.to_usize2();
+ let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -1855,6 +1877,10 @@ impl Tensor {
}
}
+ pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
+ m.forward(self)
+ }
+
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}