diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_mistral.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_mistral.rs | 41 |
1 files changed, 1 insertions, 40 deletions
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 171e7440..00c80209 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -1,5 +1,4 @@ -use crate::models::quantized_t5::Embedding; -use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::Activation; @@ -8,44 +7,6 @@ use std::sync::Arc; pub use crate::models::mistral::Config; #[derive(Debug)] -struct Linear { - weight: QMatMul, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - x.apply(&self.weight) - } -} - -fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { weight }) -} - -#[derive(Debug)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let inner = candle_nn::RmsNorm::new(weight, eps); - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -#[derive(Debug)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, |