summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-29 07:53:09 +0100
committerGitHub <noreply@github.com>2023-10-29 07:53:09 +0100
commit55bc3382cfd3a86018c54f2343567f7c0c0b677c (patch)
tree5a5c4ad535d4b5b16a01bdeb3216bbb72bafde28 /candle-nn
parentdece37c6f4d9c5a52caf59a003afa6ba33034fe3 (diff)
downloadcandle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.tar.gz
candle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.tar.bz2
candle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.zip
Allow for different behavior between training and eval (#1213)
* Forward with training. * Do not use dropout on vgg evaluation.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/func.rs35
-rw-r--r--candle-nn/src/lib.rs4
-rw-r--r--candle-nn/src/ops.rs6
3 files changed, 43 insertions, 2 deletions
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs
index 39311d45..3adfda86 100644
--- a/candle-nn/src/func.rs
+++ b/candle-nn/src/func.rs
@@ -36,3 +36,38 @@ impl<'a> Func<'a> {
Self { f: Arc::new(f) }
}
}
+
+/// A layer defined by a simple closure.
+#[derive(Clone)]
+pub struct FuncT<'a> {
+ #[allow(clippy::type_complexity)]
+ f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
+}
+
+impl<'a> std::fmt::Debug for FuncT<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "func")
+ }
+}
+
+pub fn func_t<'a, F>(f: F) -> FuncT<'a>
+where
+ F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+{
+ FuncT { f: Arc::new(f) }
+}
+
+impl<'a> super::ModuleT for FuncT<'a> {
+ fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+ (*self.f)(xs, train)
+ }
+}
+
+impl<'a> FuncT<'a> {
+ pub fn new<F>(f: F) -> Self
+ where
+ F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+ {
+ Self { f: Arc::new(f) }
+ }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index be95f531..52d8f0c5 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -22,7 +22,7 @@ pub use conv::{
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
-pub use func::{func, Func};
+pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
-pub use candle::Module;
+pub use candle::{Module, ModuleT};
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 32de1af9..e9812108 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -84,6 +84,12 @@ impl Dropout {
}
}
+impl candle::ModuleT for Dropout {
+ fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+ self.forward(xs, train)
+ }
+}
+
struct SoftmaxLastDim;
impl candle::CustomOp1 for SoftmaxLastDim {