| | | |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-29 19:12:16 +0100 |
| committer | GitHub <noreply@github.com> | 2023-08-29 19:12:16 +0100 |
| commit | 2d3fcad26788dff3fa73996a3cc8e5fd5382f6b2 (patch) | |
| tree | d2ab2b3a5b0a08903123a9039319f30e9fe6cc07 /candle-core/src/tensor.rs | |
| parent | b31d41e26a47d91d828e0c4f567f14b659775e5e (diff) | |
Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions.
* Small tweak.
* Attempt at using apply to simplify the convnet definition.
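To illustrate the first point, here is a minimal sketch of how the pooling calls look after this patch; the input tensor and its shape are illustrative only. A bare `usize` now covers the common square-kernel, equal-stride case, while the new `_with_stride` variants keep kernel and stride separate.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Dummy NCHW input, purely for illustration.
    let xs = Tensor::zeros((1, 3, 8, 8), DType::F32, &Device::Cpu)?;

    // Shorthand: a single size is used for both the kernel and the stride.
    let avg = xs.avg_pool2d(2)?; // 2x2 kernel, 2x2 stride -> (1, 3, 4, 4)

    // Explicit form when kernel and stride differ.
    let max = xs.max_pool2d_with_stride((2, 2), (1, 1))?; // -> (1, 3, 7, 7)

    println!("{:?} {:?}", avg.shape(), max.shape());
    Ok(())
}
```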
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 30 |
1 file changed, 28 insertions(+), 2 deletions(-)
```diff
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 75b3743d..f834e040 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -797,7 +797,18 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
     }
 
-    pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+    pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+        let sz = sz.to_usize2();
+        self.avg_pool2d_with_stride(sz, sz)
+    }
+
+    pub fn avg_pool2d_with_stride<T: crate::ToUsize2>(
+        &self,
+        kernel_size: T,
+        stride: T,
+    ) -> Result<Self> {
+        let kernel_size = kernel_size.to_usize2();
+        let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
         // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -813,7 +824,18 @@ impl Tensor {
         Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
     }
 
-    pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+    pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> {
+        let sz = sz.to_usize2();
+        self.max_pool2d_with_stride(sz, sz)
+    }
+
+    pub fn max_pool2d_with_stride<T: crate::ToUsize2>(
+        &self,
+        kernel_size: T,
+        stride: T,
+    ) -> Result<Self> {
+        let kernel_size = kernel_size.to_usize2();
+        let stride = stride.to_usize2();
         let (n, c, h, w) = self.dims4()?;
         // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
         let h_out = (h - kernel_size.0) / stride.0 + 1;
@@ -1855,6 +1877,10 @@ impl Tensor {
         }
     }
 
+    pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
+        m.forward(self)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
```
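Finally, a hedged sketch of how the new `apply` helper can tidy a convnet forward pass, assuming `Conv2d` and `Linear` from candle-nn (which implement the `Module` trait the bound refers to). The `TinyNet` struct and its fields are hypothetical and only show the chaining style the last commit bullet mentions; it is not the actual convnet from the candle examples.

```rust
use candle_core::{Result, Tensor};
use candle_nn::{Conv2d, Linear};

// Hypothetical two-layer block, just to demonstrate `apply` chaining.
struct TinyNet {
    conv: Conv2d,
    fc: Linear,
}

impl TinyNet {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // `apply` forwards the tensor through each module, so the pass reads
        // as a left-to-right chain instead of nested `forward` calls.
        xs.apply(&self.conv)?
            .max_pool2d(2)?
            .flatten_from(1)?
            .apply(&self.fc)
    }
}
```

Without `apply`, the same pass would be written as nested calls such as `self.fc.forward(&self.conv.forward(xs)?...)`, which reads inside-out; the chained form keeps the data flow in reading order.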