diff options
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r-- | candle-transformers/src/models/t5.rs | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 9b3d97b8..1101d001 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -118,7 +118,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerNorm {
     weight: Tensor,
     variance_epsilon: f64,
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5DenseGatedActDense {
     wi_0: Linear,
     wi_1: Linear,
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerFF {
     dense_act: Option<T5DenseActDense>,
     gated_dense_act: Option<T5DenseGatedActDense>,
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Attention {
     q: Linear,
     k: Linear,
@@ -456,7 +456,7 @@ impl T5Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5LayerCrossAttention {
     cross_attention: T5Attention,
     layer_norm: T5LayerNorm,
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Block {
     self_attn: T5LayerSelfAttention,
     cross_attn: Option<T5LayerCrossAttention>,
@@ -608,7 +608,7 @@ impl T5Block {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
@@ -658,7 +658,7 @@ impl T5Stack {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5EncoderModel {
     encoder: T5Stack,
     device: Device,
@@ -691,7 +691,7 @@ impl T5EncoderModel {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct T5ForConditionalGeneration {
     encoder: T5Stack,
     decoder: T5Stack,