Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
 -rw-r--r--  candle-transformers/src/models/quantized_llama.rs | 4 ++++
 1 file changed, 4 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 8ac1d460..44d89f40 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::LayerNorm,
     span: tracing::Span,
@@ -27,6 +28,7 @@ impl RmsNorm {
 }
 
 // QMatMul wrapper adding some tracing.
+#[derive(Debug, Clone)]
 struct QMatMul {
     inner: candle::quantized::QMatMul,
     span: tracing::Span,
@@ -45,6 +47,7 @@ impl QMatMul {
     }
 }
 
+#[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
@@ -167,6 +170,7 @@ impl LayerWeights {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
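
Context for the change (not part of the commit itself): deriving Debug and Clone on RmsNorm, QMatMul, LayerWeights, and ModelWeights makes the whole model printable and duplicable from downstream code. A minimal caller-side sketch of what this enables; the fork_sessions helper is hypothetical, and the note about cheap clones assumes candle's tensors and quantized tensors are reference-counted internally, which this diff does not state:

    use candle_transformers::models::quantized_llama::ModelWeights;

    // Hypothetical helper: duplicate a loaded model to drive several
    // independent generation loops. This only compiles because
    // #[derive(Clone)] now exists on ModelWeights and every struct it
    // contains (LayerWeights, QMatMul, RmsNorm).
    fn fork_sessions(model: &ModelWeights, n: usize) -> Vec<ModelWeights> {
        // Assumption: clones share the underlying weight buffers, so
        // this is cheap rather than a full copy of the quantized data.
        (0..n).map(|_| model.clone()).collect()
    }

    // #[derive(Debug)] likewise allows quick inspection of a model:
    // println!("{model:?}");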