diff options
Diffstat (limited to 'candle-core/src/lib.rs')
-rw-r--r-- | candle-core/src/lib.rs | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index fa85f6e0..a0347416 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -91,3 +91,36 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; + +pub trait ToUsize2 { + fn to_usize2(self) -> (usize, usize); +} + +impl ToUsize2 for usize { + fn to_usize2(self) -> (usize, usize) { + (self, self) + } +} + +impl ToUsize2 for (usize, usize) { + fn to_usize2(self) -> (usize, usize) { + self + } +} + +// A simple trait defining a module with forward method using a single argument. +pub trait Module: std::fmt::Debug { + fn forward(&self, xs: &Tensor) -> Result<Tensor>; + + /// Change the module to use training mode vs eval mode. + /// + /// The default implementation does nothing as this is only used for a couple modules such as + /// dropout or batch-normalization. + fn set_training(&mut self, _training: bool) {} +} + +impl Module for quantized::QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.forward(xs) + } +} |