diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-18 19:30:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-18 19:30:47 +0100 |
commit | 185b54a33bae51410a667dbb212ba6f29bb6104f (patch) | |
tree | a7440d6376cb142c62800fe66f72a3378b6b6ba7 /candle-transformers | |
parent | 620c94d12e3fe37b3ca5fb7017e7208fdd955365 (diff) | |
download | candle-185b54a33bae51410a667dbb212ba6f29bb6104f.tar.gz candle-185b54a33bae51410a667dbb212ba6f29bb6104f.tar.bz2 candle-185b54a33bae51410a667dbb212ba6f29bb6104f.zip |
Make some models cloneable. (#1125)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/mistral.rs | 12 | ||||
-rw-r--r-- | candle-transformers/src/models/mixformer.rs | 14 | ||||
-rw-r--r-- | candle-transformers/src/models/mpt.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/with_tracing.rs | 7 |
5 files changed, 25 insertions, 20 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index e0ecee7b..caf96bce 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -39,7 +39,7 @@ impl Config { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct RmsNorm { inner: candle_nn::RmsNorm, span: tracing::Span, @@ -60,7 +60,7 @@ impl Module for RmsNorm { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, @@ -111,7 +111,7 @@ impl RotaryEmbedding { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Linear, @@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten unimplemented!("compile with '--features flash-attn'") } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Attention { q_proj: Linear, k_proj: Linear, @@ -279,7 +279,7 @@ impl Attention { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct DecoderLayer { self_attn: Attention, mlp: MLP, @@ -322,7 +322,7 @@ impl DecoderLayer { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec<DecoderLayer>, diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index f1fd8256..0f2c199b 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -74,7 +74,7 @@ impl Config { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Embedding { wte: E, } @@ -106,7 +106,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> Ok(m) } -#[derive(Debug)] +#[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, @@ -172,7 +172,7 @@ impl RotaryEmbedding { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { fc1: Linear, @@ -199,7 +199,7 @@ impl Module for MLP { } } -#[derive(Debug)] 
+#[derive(Debug, Clone)] struct CausalLMHead { ln: candle_nn::LayerNorm, linear: Linear, @@ -221,7 +221,7 @@ impl Module for CausalLMHead { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MHA { wqkv: Linear, @@ -310,7 +310,7 @@ impl MHA { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct ParallelBlock { ln: candle_nn::LayerNorm, mixer: MHA, @@ -345,7 +345,7 @@ impl ParallelBlock { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MixFormerSequentialForCausalLM { embedding: Embedding, blocks: Vec<ParallelBlock>, diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs index 0d91bf94..093e177c 100644 --- a/candle-transformers/src/models/mpt.rs +++ b/candle-transformers/src/models/mpt.rs @@ -40,7 +40,7 @@ impl Config { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct GroupedQueryAttention { wqkv: Linear, out_proj: Linear, @@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct Ffn { up_proj: Linear, down_proj: Linear, @@ -169,7 +169,7 @@ impl Module for Ffn { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct MPTBlock { norm1: LayerNorm, // Do we need the low-precision variant? 
attn: GroupedQueryAttention, @@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> { alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes) } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Model { wte: Embedding, blocks: Vec<MPTBlock>, diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 8ac1d460..44d89f40 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module}; pub const MAX_SEQ_LEN: usize = 4096; +#[derive(Debug, Clone)] struct RmsNorm { inner: candle_nn::LayerNorm, span: tracing::Span, @@ -27,6 +28,7 @@ impl RmsNorm { } // QMatMul wrapper adding some tracing. +#[derive(Debug, Clone)] struct QMatMul { inner: candle::quantized::QMatMul, span: tracing::Span, @@ -45,6 +47,7 @@ impl QMatMul { } } +#[derive(Debug, Clone)] struct LayerWeights { attention_wq: QMatMul, attention_wk: QMatMul, @@ -167,6 +170,7 @@ impl LayerWeights { } } +#[derive(Debug, Clone)] pub struct ModelWeights { tok_embeddings: Embedding, layers: Vec<LayerWeights>, diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 09a243ac..edd8d657 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -1,7 +1,7 @@ use candle::{Module, Result, Tensor}; use candle_nn::VarBuilder; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Embedding { inner: candle_nn::Embedding, span: tracing::Span, @@ -26,7 +26,7 @@ impl Module for Embedding { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Linear { inner: candle_nn::Linear, span: tracing::Span, @@ -52,7 +52,7 @@ impl Module for Linear { } // Wrap the conv2d op to provide some tracing. 
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Conv2d { inner: candle_nn::Conv2d, span: tracing::Span, @@ -78,6 +78,7 @@ pub fn conv2d( } // QMatMul wrapper adding some tracing. +#[derive(Clone)] pub struct QMatMul { inner: candle::quantized::QMatMul, span: tracing::Span, |