summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/t5.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r--candle-transformers/src/models/t5.rs22
1 files changed, 11 insertions, 11 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 9b3d97b8..1101d001 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -118,7 +118,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
wi_0: Linear,
wi_1: Linear,
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerFF {
dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>,
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Attention {
q: Linear,
k: Linear,
@@ -456,7 +456,7 @@ impl T5Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
@@ -608,7 +608,7 @@ impl T5Block {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
@@ -658,7 +658,7 @@ impl T5Stack {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5EncoderModel {
encoder: T5Stack,
device: Device,
@@ -691,7 +691,7 @@ impl T5EncoderModel {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,