summaryrefslogtreecommitdiff
path: root/candle-core/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/lib.rs')
-rw-r--r--candle-core/src/lib.rs33
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index fa85f6e0..a0347416 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -91,3 +91,36 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+
+pub trait ToUsize2 {
+ fn to_usize2(self) -> (usize, usize);
+}
+
+impl ToUsize2 for usize {
+ fn to_usize2(self) -> (usize, usize) {
+ (self, self)
+ }
+}
+
+impl ToUsize2 for (usize, usize) {
+ fn to_usize2(self) -> (usize, usize) {
+ self
+ }
+}
+
+// A simple trait defining a module with forward method using a single argument.
+pub trait Module: std::fmt::Debug {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+
+ /// Change the module to use training mode vs eval mode.
+ ///
+ /// The default implementation does nothing as this is only used for a couple modules such as
+ /// dropout or batch-normalization.
+ fn set_training(&mut self, _training: bool) {}
+}
+
+impl Module for quantized::QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward(xs)
+ }
+}