diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-29 07:53:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-29 07:53:09 +0100 |
commit | 55bc3382cfd3a86018c54f2343567f7c0c0b677c (patch) | |
tree | 5a5c4ad535d4b5b16a01bdeb3216bbb72bafde28 /candle-nn | |
parent | dece37c6f4d9c5a52caf59a003afa6ba33034fe3 (diff) | |
download | candle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.tar.gz candle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.tar.bz2 candle-55bc3382cfd3a86018c54f2343567f7c0c0b677c.zip |
Allow for different behavior between training and eval (#1213)
* Forward with training.
* Do not use dropout on vgg evaluation.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/func.rs | 35 | ||||
-rw-r--r-- | candle-nn/src/lib.rs | 4 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 6 |
3 files changed, 43 insertions, 2 deletions
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index 39311d45..3adfda86 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -36,3 +36,38 @@ impl<'a> Func<'a> { Self { f: Arc::new(f) } } } + +/// A layer defined by a simple closure. +#[derive(Clone)] +pub struct FuncT<'a> { + #[allow(clippy::type_complexity)] + f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>, +} + +impl<'a> std::fmt::Debug for FuncT<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func_t<'a, F>(f: F) -> FuncT<'a> +where + F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync, +{ + FuncT { f: Arc::new(f) } +} + +impl<'a> super::ModuleT for FuncT<'a> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> { + (*self.f)(xs, train) + } +} + +impl<'a> FuncT<'a> { + pub fn new<F>(f: F) -> Self + where + F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync, + { + Self { f: Arc::new(f) } + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index be95f531..52d8f0c5 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -22,7 +22,7 @@ pub use conv::{ Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig, }; pub use embedding::{embedding, Embedding}; -pub use func::{func, Func}; +pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; @@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; -pub use candle::Module; +pub use candle::{Module, ModuleT}; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 32de1af9..e9812108 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -84,6 +84,12 @@ impl Dropout { } } +impl candle::ModuleT for Dropout { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> { + self.forward(xs, train) + 
} +} + struct SoftmaxLastDim; impl candle::CustomOp1 for SoftmaxLastDim {