From 185b54a33bae51410a667dbb212ba6f29bb6104f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 18 Oct 2023 19:30:47 +0100
Subject: Make some model cloneable. (#1125)

---
 candle-transformers/src/models/mistral.rs         | 12 ++++++------
 candle-transformers/src/models/mixformer.rs       | 14 +++++++-------
 candle-transformers/src/models/mpt.rs             |  8 ++++----
 candle-transformers/src/models/quantized_llama.rs |  4 ++++
 candle-transformers/src/models/with_tracing.rs    |  7 ++++---
 5 files changed, 25 insertions(+), 20 deletions(-)

(limited to 'candle-transformers')

diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e0ecee7b..caf96bce 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -39,7 +39,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
@@ -60,7 +60,7 @@ impl Module for RmsNorm {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index f1fd8256..0f2c199b 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -74,7 +74,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Embedding {
     wte: E,
 }
@@ -106,7 +106,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     Ok(m)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -172,7 +172,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     fc1: Linear,
@@ -199,7 +199,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
     linear: Linear,
@@ -221,7 +221,7 @@ impl Module for CausalLMHead {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
     wqkv: Linear,
@@ -310,7 +310,7 @@ impl MHA {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
@@ -345,7 +345,7 @@ impl ParallelBlock {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 0d91bf94..093e177c 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -40,7 +40,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct GroupedQueryAttention {
     wqkv: Linear,
     out_proj: Linear,
@@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
     down_proj: Linear,
@@ -169,7 +169,7 @@ impl Module for Ffn {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MPTBlock {
     norm1: LayerNorm, // Do we need the low-precision variant?
     attn: GroupedQueryAttention,
@@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
     alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     wte: Embedding,
     blocks: Vec<MPTBlock>,
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8ac1d460..44d89f40 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::LayerNorm,
     span: tracing::Span,
@@ -27,6 +28,7 @@ impl RmsNorm {
     }
 }
 // QMatMul wrapper adding some tracing.
+#[derive(Debug, Clone)]
 struct QMatMul {
     inner: candle::quantized::QMatMul,
     span: tracing::Span,
@@ -45,6 +47,7 @@ impl QMatMul {
     }
 }
 
+#[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
@@ -167,6 +170,7 @@ impl LayerWeights {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 09a243ac..edd8d657 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -1,7 +1,7 @@
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Embedding {
     inner: candle_nn::Embedding,
     span: tracing::Span,
@@ -26,7 +26,7 @@ impl Module for Embedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Linear {
     inner: candle_nn::Linear,
     span: tracing::Span,
@@ -52,7 +52,7 @@ impl Module for Linear {
     }
 }
 // Wrap the conv2d op to provide some tracing.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Conv2d {
     inner: candle_nn::Conv2d,
     span: tracing::Span,
@@ -78,6 +78,7 @@ pub fn conv2d(
     )
 }
 // QMatMul wrapper adding some tracing.
+#[derive(Clone)]
 pub struct QMatMul {
     inner: candle::quantized::QMatMul,
     span: tracing::Span,
--
cgit v1.2.3
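
Note on the change: `#[derive(Clone)]` on a top-level model only compiles when every field type is itself `Clone`, which is why the derive has to be threaded through each nested struct (RmsNorm, MLP, MHA, the tracing wrappers in with_tracing.rs, and so on). The sketch below mirrors that pattern with hypothetical stand-in types, not the actual candle code: the usual motivation is keeping a pristine copy of a loaded model and cloning it per request instead of reloading the weights.

// Minimal sketch of the pattern this patch applies. Linear, Layer and Model
// here are hypothetical stand-ins, not the candle-transformers structs.
// The derive on Model only compiles because every field, all the way down,
// is Clone as well -- which is what the patch adds file by file.

#[derive(Debug, Clone)]
struct Linear {
    weight: Vec<f32>, // stand-in for a weight tensor
}

#[derive(Debug, Clone)]
struct Layer {
    proj: Linear,
    kv_cache: Option<Vec<f32>>, // per-request state we may want to discard
}

#[derive(Debug, Clone)]
struct Model {
    layers: Vec<Layer>,
}

fn main() {
    let pristine = Model {
        layers: vec![Layer {
            proj: Linear { weight: vec![0.0; 4] },
            kv_cache: None,
        }],
    };

    // Clone a fresh copy per request instead of reloading the weights.
    let mut per_request = pristine.clone();
    per_request.layers[0].kv_cache = Some(vec![1.0, 2.0]);

    // The pristine copy is unaffected.
    assert!(pristine.layers[0].kv_cache.is_none());
}

In candle itself the clone should stay relatively cheap, since tensors behave like reference-counted handles and a cloned model shares rather than copies the underlying weight buffers; only small per-layer state (caches, spans) is duplicated.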