summaryrefslogtreecommitdiff
path: root/candle-transformers/src/quantized_nn.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-09 06:22:22 +0100
committerGitHub <noreply@github.com>2023-10-09 06:22:22 +0100
commit392fe02fba96658bafc73100e80bf68d54e4e23f (patch)
tree3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8 /candle-transformers/src/quantized_nn.rs
parent59ab6d7832600083a1519aa0511e9c7c832ae01c (diff)
downloadcandle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip
Move the common quantized-nn code to a shared module. (#1063)
Diffstat (limited to 'candle-transformers/src/quantized_nn.rs')
-rw-r--r--candle-transformers/src/quantized_nn.rs87
1 file changed, 87 insertions, 0 deletions
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
new file mode 100644
index 00000000..1745327d
--- /dev/null
+++ b/candle-transformers/src/quantized_nn.rs
@@ -0,0 +1,87 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor};
+
/// Embedding layer built from a quantized weight that is dequantized once at
/// load time; lookups then go through the standard `candle_nn` embedding.
#[derive(Debug)]
pub struct Embedding {
    // Dequantized embedding table wrapped in the regular candle-nn layer.
    inner: candle_nn::Embedding,
    // Tracing span entered on every forward pass for profiling.
    span: tracing::Span,
}
+
+impl Embedding {
+ pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
+ let inner = candle_nn::Embedding::new(embeddings, d2);
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ pub fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
/// Linear layer whose weight stays in quantized form — the matmul runs
/// through `QMatMul` — while the optional bias is a dequantized tensor.
#[derive(Debug)]
pub struct Linear {
    // Quantized weight matrix, applied via quantized matmul.
    weight: QMatMul,
    // Dequantized bias, broadcast-added after the matmul when present.
    bias: Option<Tensor>,
}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let x = x.apply(&self.weight)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear {
+ weight,
+ bias: Some(bias),
+ })
+}
+
+pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+ Ok(candle_nn::LayerNorm::new(weight, bias, eps))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear { weight, bias: None })
+}
+
/// RMS normalization layer whose scale vector is dequantized at load time,
/// instrumented with a tracing span.
#[derive(Debug)]
pub struct RmsNorm {
    // Standard candle-nn RMS norm over the dequantized weight.
    inner: candle_nn::RmsNorm,
    // Tracing span entered on every forward pass.
    span: tracing::Span,
}
+
+impl RmsNorm {
+ pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let inner = candle_nn::RmsNorm::new(weight, eps);
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}