summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-20 16:08:50 +0100
committerGitHub <noreply@github.com>2023-10-20 16:08:50 +0100
commit99cf13e8e2f15b700c052ff8ec7b20f42badd96a (patch)
tree0e38c909a93d683c3ec72b378c904a2920571a5f /candle-nn
parentb43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289 (diff)
downloadcandle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.tar.gz
candle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.tar.bz2
candle-99cf13e8e2f15b700c052ff8ec7b20f42badd96a.zip
Add the sequential layer. (#1136)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/lib.rs2
-rw-r--r--candle-nn/src/sequential.rs62
2 files changed, 64 insertions, 0 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 8e5580df..be95f531 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -11,6 +11,7 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
+pub mod sequential;
pub mod var_builder;
pub mod var_map;
@@ -29,6 +30,7 @@ pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
+pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs
new file mode 100644
index 00000000..2fef7742
--- /dev/null
+++ b/candle-nn/src/sequential.rs
@@ -0,0 +1,62 @@
+//! A sequential layer used to chain multiple layers and closures.
+use candle::{Module, Result, Tensor};
+
+/// A sequential layer combining multiple other layers.
+pub struct Sequential {
+    // The sub-layers, stored type-erased and applied in insertion order
+    // by the `Module::forward` implementation below.
+    layers: Vec<Box<dyn Module>>,
+}
+
+/// Creates a new empty sequential layer.
+///
+/// Layers are then attached with [`Sequential::add`] / [`Sequential::add_fn`],
+/// which consume and return the container builder-style.
+pub fn seq() -> Sequential {
+    Sequential { layers: Vec::new() }
+}
+
+impl Sequential {
+    /// Returns true if this layer does not have any sub-layer.
+    pub fn is_empty(&self) -> bool {
+        self.layers.is_empty()
+    }
+
+    /// The number of sub-layers embedded in this layer.
+    //
+    // NOTE(review): returning `i64` rather than `usize` is unusual for Rust;
+    // presumably kept for parity with other bindings — confirm before changing,
+    // as callers may rely on the signature.
+    pub fn len(&self) -> i64 {
+        let count = self.layers.len();
+        count as i64
+    }
+}
+
+impl Module for Sequential {
+    /// Threads `xs` through every sub-layer in order, short-circuiting on the
+    /// first error. Equivalent to a manual loop over `self.layers`.
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.layers
+            .iter()
+            .try_fold(xs.clone(), |acc, layer| layer.forward(&acc))
+    }
+}
+
+impl Sequential {
+    /// Appends a layer after all the current layers.
+    // `add` consumes and returns `self` builder-style rather than implementing
+    // `std::ops::Add`, hence the clippy allowance.
+    #[allow(clippy::should_implement_trait)]
+    pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
+        self.layers.push(Box::new(layer));
+        self
+    }
+
+    /// Appends a closure after all the current layers.
+    pub fn add_fn<F>(self, f: F) -> Self
+    where
+        F: 'static + Fn(&Tensor) -> Result<Tensor> + Send,
+    {
+        // Wrap the closure in the crate's function-module adapter so it can be
+        // boxed like any other `Module`.
+        self.add(super::func(f))
+    }
+
+    /// Applies the forward pass and returns the output for each layer.
+    ///
+    /// The returned vector has one entry per sub-layer: the intermediate
+    /// activation produced right after that layer ran.
+    pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
+        let mut outputs = Vec::with_capacity(self.layers.len());
+        let mut current = xs.clone();
+        for layer in &self.layers {
+            current = layer.forward(&current)?;
+            outputs.push(current.clone());
+        }
+        Ok(outputs)
+    }
+}