summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_mpt.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/quantized_mpt.rs')
-rw-r--r--candle-transformers/src/models/quantized_mpt.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 7586e4c0..70a9e125 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -7,7 +7,7 @@ use candle_nn::LayerNorm;
pub use super::mpt::Config;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct GroupedQueryAttention {
wqkv: Linear,
out_proj: Linear,
@@ -101,7 +101,7 @@ impl GroupedQueryAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,
down_proj: Linear,
@@ -122,7 +122,7 @@ impl Module for Ffn {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct MPTBlock {
norm1: LayerNorm, // Do we need the low-precision variant?
attn: GroupedQueryAttention,
@@ -155,7 +155,7 @@ impl MPTBlock {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
wte: Embedding,
blocks: Vec<MPTBlock>,