diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-20 11:33:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-20 11:33:51 +0100 |
commit | 7b1ddcff47b14f80c3d11e65aa9b452ecc17d59a (patch) | |
tree | fbf3be7a7646818d90b9bd90dab6999f22eca6ff /candle-nn/src | |
parent | f685b2231cad71e98257b62618b364c255feacd7 (diff) | |
download | candle-7b1ddcff47b14f80c3d11e65aa9b452ecc17d59a.tar.gz candle-7b1ddcff47b14f80c3d11e65aa9b452ecc17d59a.tar.bz2 candle-7b1ddcff47b14f80c3d11e65aa9b452ecc17d59a.zip |
Add clone to various nn layers. (#910)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/batch_norm.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/conv.rs | 6 | ||||
-rw-r--r-- | candle-nn/src/embedding.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/group_norm.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/layer_norm.rs | 4 | ||||
-rw-r--r-- | candle-nn/src/linear.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/rnn.rs | 4 |
7 files changed, 11 insertions, 11 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 2dac0758..27ef15f7 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct BatchNorm { running_mean: Tensor, running_var: Tensor, diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 31bf9af0..89e9f42d 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -20,7 +20,7 @@ impl Default for Conv1dConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Conv1d { weight: Tensor, bias: Option<Tensor>, @@ -88,7 +88,7 @@ impl Default for Conv2dConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Conv2d { weight: Tensor, bias: Option<Tensor>, @@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ConvTranspose2d { weight: Tensor, bias: Option<Tensor>, diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index fccc8a17..52968bc2 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -1,7 +1,7 @@ //! Embedding Layer. use candle::{Result, Tensor}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Embedding { embeddings: Tensor, hidden_size: usize, diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index eb1b889f..5b80b970 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -4,7 +4,7 @@ use candle::{DType, Result, Tensor}; // This group norm version handles both weight and bias so removes the mean. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct GroupNorm { weight: Tensor, bias: Tensor, diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index d2e80a82..7617fc6c 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig { } // This layer norm version handles both weight and bias so removes the mean. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LayerNorm { weight: Tensor, bias: Option<Tensor>, @@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>( } /// RmsNorm is a specialized version of the LayerNorm module. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct RmsNorm(LayerNorm); impl RmsNorm { diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index de335964..94632296 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -19,7 +19,7 @@ //! ``` use candle::{Result, Tensor}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Linear { weight: Tensor, bias: Option<Tensor>, diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index d52a9082..18a4a71c 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -85,7 +85,7 @@ impl LSTMConfig { /// /// <https://en.wikipedia.org/wiki/Long_short-term_memory> #[allow(clippy::upper_case_acronyms, unused)] -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LSTM { w_ih: Tensor, w_hh: Tensor, @@ -235,7 +235,7 @@ impl GRUConfig { /// /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit> #[allow(clippy::upper_case_acronyms, unused)] -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct GRU { w_ih: Tensor, w_hh: Tensor, |