author     Laurent Mazare <laurent.mazare@gmail.com>    2023-10-18 19:30:47 +0100
committer  GitHub <noreply@github.com>                  2023-10-18 19:30:47 +0100
commit     185b54a33bae51410a667dbb212ba6f29bb6104f (patch)
tree       a7440d6376cb142c62800fe66f72a3378b6b6ba7 /candle-transformers
parent     620c94d12e3fe37b3ca5fb7017e7208fdd955365 (diff)
Make some models cloneable. (#1125)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/mistral.rs          | 12 ++++++------
-rw-r--r--  candle-transformers/src/models/mixformer.rs        | 14 +++++++-------
-rw-r--r--  candle-transformers/src/models/mpt.rs              |  8 ++++----
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  |  4 ++++
-rw-r--r--  candle-transformers/src/models/with_tracing.rs     |  7 ++++---
5 files changed, 25 insertions(+), 20 deletions(-)
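
The patch itself is mechanical: a Clone derive is added to each wrapper struct, starting at the leaf modules (RmsNorm, RotaryEmbedding, MLP, Attention, ...) and ending at the top-level Model / ModelWeights types. A derived Clone only compiles when every field is itself Clone, so the derive has to be threaded through the whole hierarchy. Below is a minimal standalone sketch of that constraint, using toy types rather than the candle API:

// Toy stand-ins for the candle wrappers. #[derive(Clone)] on a struct
// requires every field to implement Clone, so the derive is added
// bottom-up: Norm -> Layer -> Model, mirroring the patch.
#[derive(Debug, Clone)]
struct Norm {
    weight: Vec<f32>, // stand-in for candle_nn::RmsNorm
}

#[derive(Debug, Clone)]
struct Layer {
    norm: Norm,
}

#[derive(Debug, Clone)]
struct Model {
    layers: Vec<Layer>,
}

fn main() {
    let base = Model {
        layers: vec![Layer { norm: Norm { weight: vec![1.0; 8] } }],
    };
    // Each clone is an independent value, e.g. for running several
    // decoding streams without sharing per-request mutable state.
    let worker = base.clone();
    assert_eq!(worker.layers.len(), base.layers.len());
}

The per-file diffs below apply exactly this pattern to the wrapper types in mistral, mixformer, mpt, quantized_llama and with_tracing.
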
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index e0ecee7b..caf96bce 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -39,7 +39,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@@ -60,7 +60,7 @@ impl Module for RmsNorm {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
@@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
unimplemented!("compile with '--features flash-attn'")
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
@@ -279,7 +279,7 @@ impl Attention {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
@@ -322,7 +322,7 @@ impl DecoderLayer {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index f1fd8256..0f2c199b 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -74,7 +74,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Embedding {
wte: E,
}
@@ -106,7 +106,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
@@ -172,7 +172,7 @@ impl RotaryEmbedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
@@ -199,7 +199,7 @@ impl Module for MLP {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct CausalLMHead {
ln: candle_nn::LayerNorm,
linear: Linear,
@@ -221,7 +221,7 @@ impl Module for CausalLMHead {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MHA {
wqkv: Linear,
@@ -310,7 +310,7 @@ impl MHA {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct ParallelBlock {
ln: candle_nn::LayerNorm,
mixer: MHA,
@@ -345,7 +345,7 @@ impl ParallelBlock {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct MixFormerSequentialForCausalLM {
embedding: Embedding,
blocks: Vec<ParallelBlock>,
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 0d91bf94..093e177c 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -40,7 +40,7 @@ impl Config {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct GroupedQueryAttention {
wqkv: Linear,
out_proj: Linear,
@@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,
down_proj: Linear,
@@ -169,7 +169,7 @@ impl Module for Ffn {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
struct MPTBlock {
norm1: LayerNorm, // Do we need the low-precision variant?
attn: GroupedQueryAttention,
@@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Model {
wte: Embedding,
blocks: Vec<MPTBlock>,
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8ac1d460..44d89f40 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
+#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
@@ -27,6 +28,7 @@ impl RmsNorm {
}
// QMatMul wrapper adding some tracing.
+#[derive(Debug, Clone)]
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
@@ -45,6 +47,7 @@ impl QMatMul {
}
}
+#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
@@ -167,6 +170,7 @@ impl LayerWeights {
}
}
+#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 09a243ac..edd8d657 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -1,7 +1,7 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Embedding {
inner: candle_nn::Embedding,
span: tracing::Span,
@@ -26,7 +26,7 @@ impl Module for Embedding {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
@@ -52,7 +52,7 @@ impl Module for Linear {
}
// Wrap the conv2d op to provide some tracing.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
@@ -78,6 +78,7 @@ pub fn conv2d(
}
// QMatMul wrapper adding some tracing.
+#[derive(Clone)]
pub struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
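
As a closing note, and an observation about candle internals rather than something the patch states: candle tensors are reference-counted, so a derived Clone on these wrappers duplicates the struct layout while the underlying weight buffers stay shared, which keeps cloning a loaded model cheap in both time and memory. A toy sketch of that sharing behaviour with plain std types (not the candle API):

use std::sync::Arc;

// When the weight buffer sits behind an Arc, cloning the wrapper copies
// the pointer, not the data, so cloning a whole model built from such
// wrappers stays cheap.
#[derive(Debug, Clone)]
struct Linear {
    weight: Arc<Vec<f32>>, // shared, immutable weights
}

#[derive(Debug, Clone)]
struct Model {
    lm_head: Linear,
}

fn main() {
    let model = Model {
        lm_head: Linear { weight: Arc::new(vec![0.0f32; 1024]) },
    };
    let clone = model.clone();
    // Both handles point at the same weight buffer.
    assert!(Arc::ptr_eq(&model.lm_head.weight, &clone.lm_head.weight));
    println!("strong refs: {}", Arc::strong_count(&model.lm_head.weight));
}
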