Diffstat (limited to 'candle-transformers/src/models/mistral.rs')
-rw-r--r--  candle-transformers/src/models/mistral.rs  12
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e0ecee7b..caf96bce 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -39,7 +39,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@@ -60,7 +60,7 @@ impl Module for RmsNorm {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
@@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
@@ -279,7 +279,7 @@ impl Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
@@ -322,7 +322,7 @@ impl DecoderLayer {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
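
The commit derives `Clone` across the module's whole ownership chain (RmsNorm, RotaryEmbedding, MLP, Attention, DecoderLayer, Model). A derived `Clone` only compiles when every field is itself `Clone`, so the derive cannot go on `Model` alone; it has to be added to each nested struct in the same change. In candle, tensor clones are reference-counted rather than deep copies, so these derives are expected to be cheap, though the diff itself does not state that. Below is a minimal standalone sketch with hypothetical names, not candle code, illustrating why the derives must land together:

```rust
// Minimal standalone sketch (not candle code; names are hypothetical).
// `#[derive(Clone)]` only compiles when every field is itself `Clone`,
// which is why the derive is added to RmsNorm, RotaryEmbedding, MLP,
// Attention, DecoderLayer and Model together rather than to the
// top-level Model alone.
#[derive(Debug, Clone)]
struct DecoderLayerLike {
    // Stands in for the Linear / Tensor fields of the real DecoderLayer.
    weights: Vec<f32>,
}

#[derive(Debug, Clone)]
struct ModelLike {
    layers: Vec<DecoderLayerLike>,
}

fn main() {
    let model = ModelLike {
        layers: vec![DecoderLayerLike { weights: vec![0.0; 4] }],
    };
    // This clone only compiles because the whole ownership chain is Clone.
    let copy = model.clone();
    println!("{} layer(s) in each copy", copy.layers.len());
    drop(model);
}
```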