diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-09 06:22:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-09 06:22:22 +0100 |
commit | 392fe02fba96658bafc73100e80bf68d54e4e23f (patch) | |
tree | 3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8 /candle-transformers/src/models/quantized_stable_lm.rs | |
parent | 59ab6d7832600083a1519aa0511e9c7c832ae01c (diff) | |
download | candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2 candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip |
Move the common quantized-nn code to a shared module. (#1063)
Diffstat (limited to 'candle-transformers/src/models/quantized_stable_lm.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_stable_lm.rs | 25 |
1 file changed, 1 insertion, 24 deletions
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs index 86964237..304e91ee 100644 --- a/candle-transformers/src/models/quantized_stable_lm.rs +++ b/candle-transformers/src/models/quantized_stable_lm.rs @@ -1,5 +1,4 @@ -use crate::models::quantized_t5::Embedding; -use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm}; @@ -9,28 +8,6 @@ pub use crate::models::stable_lm::Config; use crate::models::stable_lm::RotaryEmbedding; #[derive(Debug)] -struct Linear { - weight: QMatMul, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - x.apply(&self.weight) - } -} - -fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { weight }) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let bias = vb.get(size, "bias")?.dequantize(vb.device())?; - Ok(candle_nn::LayerNorm::new(weight, bias, eps)) -} - -#[derive(Debug)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Linear, |