diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-28 08:43:08 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-28 07:43:08 +0100 |
commit | 612f5b81561150ca6651368c245ac2065c04159a (patch) | |
tree | 2c80848778d4d67dce7e7f5803155e5cc44c0f57 /candle-transformers | |
parent | ef33df7ae2b94e2b911b61f3765d6826726614e7 (diff) | |
download | candle-612f5b81561150ca6651368c245ac2065c04159a.tar.gz candle-612f5b81561150ca6651368c245ac2065c04159a.tar.bz2 candle-612f5b81561150ca6651368c245ac2065c04159a.zip |
Make more models cloneable. (#1203)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/quantized_stable_lm.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs | 22 | ||||
-rw-r--r-- | candle-transformers/src/models/t5.rs | 22 |
3 files changed, 26 insertions(+), 26 deletions(-)
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index d117e4b3..94c96201 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -7,7 +7,7 @@ use std::sync::Arc;
 pub use crate::models::stable_lm::Config;
 use crate::models::stable_lm::RotaryEmbedding;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -43,7 +43,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -168,7 +168,7 @@ impl Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -213,7 +213,7 @@ impl DecoderLayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: Embedding,
     layers: Vec<DecoderLayer>,
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 1426df39..4e5bd81a 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -93,7 +93,7 @@ impl Default for Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
@@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseActDense {
     wi: QMatMul,
     wo: QMatMul,
@@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseGatedActDense {
     wi_0: QMatMul,
     wi_1: QMatMul,
@@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
@@ -236,7 +236,7 @@ impl Module for T5LayerFF {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Attention {
     q: QMatMul,
     k: QMatMul,
@@ -431,7 +431,7 @@ impl T5Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
@@ -583,7 +583,7 @@ impl T5Block {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
@@ -633,7 +633,7 @@ impl T5Stack {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
@@ -666,7 +666,7 @@ impl T5EncoderModel {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5ForConditionalGeneration {
     encoder: T5Stack,
     decoder: T5Stack,
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 9b3d97b8..1101d001 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -118,7 +118,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseGatedActDense {
     wi_0: Linear,
     wi_1: Linear,
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Attention {
     q: Linear,
     k: Linear,
@@ -456,7 +456,7 @@ impl T5Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
@@ -608,7 +608,7 @@ impl T5Block {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
@@ -658,7 +658,7 @@ impl T5Stack {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
@@ -691,7 +691,7 @@ impl T5EncoderModel {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5ForConditionalGeneration {
     encoder: T5Stack,
     decoder: T5Stack,