diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 2 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 33 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 30 |
3 files changed, 62 insertions, 3 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f4f90373..c6d55e61 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -256,7 +256,7 @@ impl Tensor { // we scale the gradient for this case). let node_upsampled = node.upsample_nearest2d(h, w)?; let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?; - let avg = mask.avg_pool2d(*kernel_size, *stride)?; + let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?; let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index fa85f6e0..a0347416 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -91,3 +91,36 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; + +pub trait ToUsize2 { + fn to_usize2(self) -> (usize, usize); +} + +impl ToUsize2 for usize { + fn to_usize2(self) -> (usize, usize) { + (self, self) + } +} + +impl ToUsize2 for (usize, usize) { + fn to_usize2(self) -> (usize, usize) { + self + } +} + +// A simple trait defining a module with forward method using a single argument. +pub trait Module: std::fmt::Debug { + fn forward(&self, xs: &Tensor) -> Result<Tensor>; + + /// Change the module to use training mode vs eval mode. + /// + /// The default implementation does nothing as this is only used for a couple modules such as + /// dropout or batch-normalization. + fn set_training(&mut self, _training: bool) {} +} + +impl Module for quantized::QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.forward(xs) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 75b3743d..f834e040 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -797,7 +797,18 @@ impl Tensor { Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) } - pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + pub fn avg_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> { + let sz = sz.to_usize2(); + self.avg_pool2d_with_stride(sz, sz) + } + + pub fn avg_pool2d_with_stride<T: crate::ToUsize2>( + &self, + kernel_size: T, + stride: T, + ) -> Result<Self> { + let kernel_size = kernel_size.to_usize2(); + let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; @@ -813,7 +824,18 @@ impl Tensor { Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) } - pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + pub fn max_pool2d<T: crate::ToUsize2>(&self, sz: T) -> Result<Self> { + let sz = sz.to_usize2(); + self.max_pool2d_with_stride(sz, sz) + } + + pub fn max_pool2d_with_stride<T: crate::ToUsize2>( + &self, + kernel_size: T, + stride: T, + ) -> Result<Self> { + let kernel_size = kernel_size.to_usize2(); + let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; @@ -1855,6 +1877,10 @@ impl Tensor { } } + pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> { + m.forward(self) + } + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } |