summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-28 08:43:08 +0200
committerGitHub <noreply@github.com>2023-10-28 07:43:08 +0100
commit612f5b81561150ca6651368c245ac2065c04159a (patch)
tree2c80848778d4d67dce7e7f5803155e5cc44c0f57 /candle-transformers
parentef33df7ae2b94e2b911b61f3765d6826726614e7 (diff)
downloadcandle-612f5b81561150ca6651368c245ac2065c04159a.tar.gz
candle-612f5b81561150ca6651368c245ac2065c04159a.tar.bz2
candle-612f5b81561150ca6651368c245ac2065c04159a.zip
Make more models cloneable. (#1203)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/quantized_stable_lm.rs8
-rw-r--r--candle-transformers/src/models/quantized_t5.rs22
-rw-r--r--candle-transformers/src/models/t5.rs22
3 files changed, 26 insertions, 26 deletions
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index d117e4b3..94c96201 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -7,7 +7,7 @@ use std::sync::Arc;
pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
@@ -43,7 +43,7 @@ impl Module for MLP {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
@@ -168,7 +168,7 @@ impl Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
@@ -213,7 +213,7 @@ impl DecoderLayer {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 1426df39..4e5bd81a 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -93,7 +93,7 @@ impl Default for Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
@@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseActDense {
wi: QMatMul,
wo: QMatMul,
@@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
wi_0: QMatMul,
wi_1: QMatMul,
@@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerFF {
dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>,
@@ -236,7 +236,7 @@ impl Module for T5LayerFF {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Attention {
q: QMatMul,
k: QMatMul,
@@ -431,7 +431,7 @@ impl T5Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
@@ -583,7 +583,7 @@ impl T5Block {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
@@ -633,7 +633,7 @@ impl T5Stack {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5EncoderModel {
encoder: T5Stack,
device: Device,
@@ -666,7 +666,7 @@ impl T5EncoderModel {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 9b3d97b8..1101d001 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -118,7 +118,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
wi_0: Linear,
wi_1: Linear,
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerFF {
dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>,
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Attention {
q: Linear,
k: Linear,
@@ -456,7 +456,7 @@ impl T5Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
@@ -608,7 +608,7 @@ impl T5Block {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
@@ -658,7 +658,7 @@ impl T5Stack {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5EncoderModel {
encoder: T5Stack,
device: Device,
@@ -691,7 +691,7 @@ impl T5EncoderModel {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,