diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-09 06:22:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-09 06:22:22 +0100 |
commit | 392fe02fba96658bafc73100e80bf68d54e4e23f (patch) | |
tree | 3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8 /candle-transformers/src/quantized_nn.rs | |
parent | 59ab6d7832600083a1519aa0511e9c7c832ae01c (diff) | |
download | candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2 candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip |
Move the common quantized-nn code to a shared module. (#1063)
Diffstat (limited to 'candle-transformers/src/quantized_nn.rs')
-rw-r--r-- | candle-transformers/src/quantized_nn.rs | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs new file mode 100644 index 00000000..1745327d --- /dev/null +++ b/candle-transformers/src/quantized_nn.rs @@ -0,0 +1,87 @@ +use crate::models::with_tracing::QMatMul; +use crate::quantized_var_builder::VarBuilder; +use candle::{Module, Result, Tensor}; + +#[derive(Debug)] +pub struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?; + let inner = candle_nn::Embedding::new(embeddings, d2); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + pub fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug)] +pub struct Linear { + weight: QMatMul, + bias: Option<Tensor>, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let x = x.apply(&self.weight)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { + weight, + bias: Some(bias), + }) +} + +pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let bias = vb.get(size, "bias")?.dequantize(vb.device())?; + Ok(candle_nn::LayerNorm::new(weight, bias, eps)) +} + +pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { weight, bias: None }) +} + +#[derive(Debug)] +pub struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let inner = candle_nn::RmsNorm::new(weight, eps); + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} |