summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/quantized_mistral.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/quantized_mistral.rs')
-rw-r--r-- candle-transformers/src/models/quantized_mistral.rs | 41
1 file changed, 1 insertion, 40 deletions
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 171e7440..00c80209 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
@@ -8,44 +7,6 @@ use std::sync::Arc;
pub use crate::models::mistral::Config;
#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- x.apply(&self.weight)
- }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear { weight })
-}
-
-#[derive(Debug)]
-struct RmsNorm {
- inner: candle_nn::RmsNorm,
- span: tracing::Span,
-}
-
-impl RmsNorm {
- fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let inner = candle_nn::RmsNorm::new(weight, eps);
- Ok(Self { inner, span })
- }
-}
-
-impl Module for RmsNorm {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-#[derive(Debug)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,