diff options
Diffstat (limited to 'candle-transformers/src/models/mpt.rs')
-rw-r--r-- | candle-transformers/src/models/mpt.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 0d91bf94..093e177c 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -40,7 +40,7 @@ impl Config {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct GroupedQueryAttention {
     wqkv: Linear,
     out_proj: Linear,
@@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
     down_proj: Linear,
@@ -169,7 +169,7 @@ impl Module for Ffn {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MPTBlock {
     norm1: LayerNorm, // Do we need the low-precision variant?
     attn: GroupedQueryAttention,
@@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
     alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     wte: Embedding,
     blocks: Vec<MPTBlock>,