summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/quantized_stable_lm.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-10-09 06:22:22 +0100
committer: GitHub <noreply@github.com> 2023-10-09 06:22:22 +0100
commit: 392fe02fba96658bafc73100e80bf68d54e4e23f (patch)
tree: 3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8 /candle-transformers/src/models/quantized_stable_lm.rs
parent: 59ab6d7832600083a1519aa0511e9c7c832ae01c (diff)
download: candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip
Move the common quantized-nn code to a shared module. (#1063)
Diffstat (limited to 'candle-transformers/src/models/quantized_stable_lm.rs')
-rw-r--r-- candle-transformers/src/models/quantized_stable_lm.rs | 25
1 file changed, 1 insertion(+), 24 deletions(-)
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 86964237..304e91ee 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm};
@@ -9,28 +8,6 @@ pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;
#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- x.apply(&self.weight)
- }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
- Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
-#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,