From 392fe02fba96658bafc73100e80bf68d54e4e23f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 9 Oct 2023 06:22:22 +0100
Subject: Move the common quantized-nn code to a shared module. (#1063)

---
 .../src/models/quantized_stable_lm.rs | 25 +---------------------
 1 file changed, 1 insertion(+), 24 deletions(-)

(limited to 'candle-transformers/src/models/quantized_stable_lm.rs')

diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 86964237..304e91ee 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -8,28 +7,6 @@ use std::sync::Arc;
 pub use crate::models::stable_lm::Config;
 use crate::models::stable_lm::RotaryEmbedding;
 
-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        x.apply(&self.weight)
-    }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
-- 
cgit v1.2.3
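
Note: the patch replaces the per-model `Linear`, `linear_no_bias`, and `layer_norm` definitions with imports from a shared `crate::quantized_nn` module. What follows is a minimal sketch of what that module presumably exports, reconstructed from the helpers deleted in the hunk above; the actual module in candle-transformers may differ, and the `Embedding` type it also exports is omitted here because its definition is not visible in this diff.

// Hedged sketch of a shared `crate::quantized_nn` module, reconstructed from the
// helpers removed by this patch. Assumption only; not necessarily the real module.
use candle::{Module, Result, Tensor};
use candle_nn::LayerNorm;

use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;

/// Bias-free linear layer backed by a quantized matmul.
#[derive(Debug)]
pub struct Linear {
    weight: QMatMul,
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Apply the quantized matmul to the input activations.
        x.apply(&self.weight)
    }
}

/// Builds a `Linear` layer (no bias term) from quantized weights.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight })
}

/// Builds a `LayerNorm` by dequantizing the stored weight and bias tensors.
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
    Ok(LayerNorm::new(weight, bias, eps))
}

With such a module in place, quantized model files like quantized_stable_lm.rs only need the single import added by this patch instead of redefining these helpers locally.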