diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-20 16:08:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-20 16:08:50 +0100 |
commit | 99cf13e8e2f15b700c052ff8ec7b20f42badd96a (patch) | |
tree | 0e38c909a93d683c3ec72b378c904a2920571a5f /candle-nn | |
parent | b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289 (diff) | |
download | candle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.tar.gz candle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.tar.bz2 candle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.zip |
Add the sequential layer. (#1136)
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/sequential.rs | 62 |
2 files changed, 64 insertions, 0 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 8e5580df..be95f531 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -11,6 +11,7 @@ pub mod loss; pub mod ops; pub mod optim; pub mod rnn; +pub mod sequential; pub mod var_builder; pub mod var_map; @@ -29,6 +30,7 @@ pub use linear::{linear, linear_no_bias, Linear}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; +pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs new file mode 100644 index 00000000..2fef7742 --- /dev/null +++ b/candle-nn/src/sequential.rs @@ -0,0 +1,62 @@ +//! A sequential layer used to chain multiple layers and closures. +use candle::{Module, Result, Tensor}; + +/// A sequential layer combining multiple other layers. +pub struct Sequential { + layers: Vec<Box<dyn Module>>, +} + +/// Creates a new empty sequential layer. +pub fn seq() -> Sequential { + Sequential { layers: vec![] } +} + +impl Sequential { + /// The number of sub-layers embedded in this layer. + pub fn len(&self) -> i64 { + self.layers.len() as i64 + } + + /// Returns true if this layer does not have any sub-layer. + pub fn is_empty(&self) -> bool { + self.layers.is_empty() + } +} + +impl Module for Sequential { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + Ok(xs) + } +} + +impl Sequential { + /// Appends a layer after all the current layers. + #[allow(clippy::should_implement_trait)] + pub fn add<M: Module + 'static>(mut self, layer: M) -> Self { + self.layers.push(Box::new(layer)); + self + } + + /// Appends a closure after all the current layers. + pub fn add_fn<F>(self, f: F) -> Self + where + F: 'static + Fn(&Tensor) -> Result<Tensor> + Send, + { + self.add(super::func(f)) + } + + /// Applies the forward pass and returns the output for each layer. + pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> { + let mut vec = Vec::with_capacity(self.layers.len()); + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)?; + vec.push(xs.clone()) + } + Ok(vec) + } +} |