diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 4 |
1 file changed, 4 insertions, 0 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8ac1d460..44d89f40 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::LayerNorm,
     span: tracing::Span,
@@ -27,6 +28,7 @@ impl RmsNorm {
 }
 
 // QMatMul wrapper adding some tracing.
+#[derive(Debug, Clone)]
 struct QMatMul {
     inner: candle::quantized::QMatMul,
     span: tracing::Span,
@@ -45,6 +47,7 @@ impl QMatMul {
     }
 }
 
+#[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
@@ -167,6 +170,7 @@ impl LayerWeights {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,