Diffstat (limited to 'candle-transformers/src/models/mistral.rs')
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 12 | ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e0ecee7b..caf96bce 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -39,7 +39,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
@@ -60,7 +60,7 @@ impl Module for RmsNorm {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
     unimplemented!("compile with '--features flash-attn'")
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -279,7 +279,7 @@ impl Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -322,7 +322,7 @@ impl DecoderLayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: candle_nn::Embedding,
     layers: Vec<DecoderLayer>,
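Deriving Clone on every submodule (RmsNorm, RotaryEmbedding, MLP, Attention, DecoderLayer) is what makes the final `#[derive(Debug, Clone)]` on the public Model compile. The sketch below, which is not part of this commit, shows one way a caller might use this: since candle tensors are reference-counted internally, cloning a loaded Model shares the weight buffers rather than copying them, so handing each request its own copy is cheap. The helper name `per_request_models` is hypothetical; `Config::config_7b_v0_1` and `Model::new` are the crate's existing constructors.

use candle_core::Result;
use candle_nn::VarBuilder;
use candle_transformers::models::mistral::{Config, Model};

// Load the weights once, then give each request its own copy of the model.
fn per_request_models(vb: VarBuilder, n: usize) -> Result<Vec<Model>> {
    let cfg = Config::config_7b_v0_1(false); // use_flash_attn = false
    let model = Model::new(&cfg, vb)?;
    // Cloning is cheap: the copies share the underlying weight storage,
    // while per-copy state such as the attention KV cache evolves
    // independently after the clone.
    Ok((0..n).map(|_| model.clone()).collect())
}

Without the Clone derives, serving several sequences concurrently would require either reloading the weights for each Model instance or sharing one instance and serializing access to its mutable KV cache.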