Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/quantized_mistral.rs    | 10 +++++-----
-rw-r--r--  candle-transformers/src/models/quantized_mixformer.rs  | 14 +++++++-------
-rw-r--r--  candle-transformers/src/models/quantized_mpt.rs        |  8 ++++----
-rw-r--r--  candle-transformers/src/quantized_nn.rs                |  6 +++---
4 files changed, 19 insertions, 19 deletions
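
The diff below adds Clone to the derive list of every struct in these quantized model files, so a loaded model can be duplicated by value. A minimal sketch of the kind of usage this enables, assuming a quantized Mistral Model has already been loaded elsewhere (the helper name run_two_sessions and the cheap-clone note are illustrative, not taken from this change):

use candle_transformers::models::quantized_mistral::Model;

// Illustrative sketch: once `Model` derives `Clone`, one loaded model can be
// duplicated so that each copy keeps its own KV-cache state. The quantized
// weights are assumed to sit behind reference-counted handles inside the
// QMatMul-backed layers, so cloning shares them rather than copying them.
fn run_two_sessions(model: &Model) {
    let mut session_a = model.clone();
    let mut session_b = model.clone();
    // ... drive session_a.forward(...) and session_b.forward(...) independently ...
    let _ = (&mut session_a, &mut session_b);
}
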
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 00c80209..9e306c67 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -6,7 +6,7 @@ use std::sync::Arc;
pub use crate::models::mistral::Config;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
@@ -57,7 +57,7 @@ impl RotaryEmbedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
@@ -90,7 +90,7 @@ impl Module for MLP {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
@@ -200,7 +200,7 @@ impl Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
@@ -243,7 +243,7 @@ impl DecoderLayer {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index 23eeb0ac..f11f2036 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -7,7 +7,7 @@ pub use crate::models::mixformer::Config;
const MAX_SEQ_LEN: usize = 4096;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Embedding {
wte: crate::quantized_nn::Embedding,
}
@@ -39,7 +39,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
@@ -105,7 +105,7 @@ impl RotaryEmbedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
@@ -132,7 +132,7 @@ impl Module for MLP {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct CausalLMHead {
ln: candle_nn::LayerNorm,
linear: Linear,
@@ -154,7 +154,7 @@ impl Module for CausalLMHead {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MHA {
wqkv: Linear,
@@ -243,7 +243,7 @@ impl MHA {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct ParallelBlock {
ln: candle_nn::LayerNorm,
mixer: MHA,
@@ -278,7 +278,7 @@ impl ParallelBlock {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct MixFormerSequentialForCausalLM {
embedding: Embedding,
blocks: Vec<ParallelBlock>,
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 7586e4c0..70a9e125 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -7,7 +7,7 @@ use candle_nn::LayerNorm;
pub use super::mpt::Config;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct GroupedQueryAttention {
wqkv: Linear,
out_proj: Linear,
@@ -101,7 +101,7 @@ impl GroupedQueryAttention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,
down_proj: Linear,
@@ -122,7 +122,7 @@ impl Module for Ffn {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct MPTBlock {
norm1: LayerNorm, // Do we need the low-precision variant?
attn: GroupedQueryAttention,
@@ -155,7 +155,7 @@ impl MPTBlock {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
wte: Embedding,
blocks: Vec<MPTBlock>,
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index d71c3b60..2941c3f0 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -2,7 +2,7 @@ use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor};
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Embedding {
inner: candle_nn::Embedding,
span: tracing::Span,
@@ -28,7 +28,7 @@ impl Module for Embedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Linear {
weight: QMatMul,
bias: Option<Tensor>,
@@ -69,7 +69,7 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
Ok(Linear { weight, bias: None })
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
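
The quantized_nn.rs changes at the end are what make the model-level derives possible: a struct can only derive Clone when every field type implements it, and these building blocks (Embedding, Linear, RmsNorm) appear as fields of the attention, MLP, and decoder-layer structs above. A small hedged sketch of that pattern, with a made-up struct name for illustration:

use candle_transformers::quantized_nn::{Embedding, Linear, RmsNorm};

// Illustrative only: deriving Clone on this composite struct compiles
// precisely because Embedding, Linear, and RmsNorm now implement Clone.
#[derive(Debug, Clone)]
struct TinyBlock {
    embed: Embedding,
    proj: Linear,
    norm: RmsNorm,
}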