summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-09 06:22:22 +0100
committerGitHub <noreply@github.com>2023-10-09 06:22:22 +0100
commit392fe02fba96658bafc73100e80bf68d54e4e23f (patch)
tree3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8
parent59ab6d7832600083a1519aa0511e9c7c832ae01c (diff)
downloadcandle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2
candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip
Move the common quantized-nn code to a shared module. (#1063)
-rw-r--r--candle-transformers/src/lib.rs1
-rw-r--r--candle-transformers/src/models/quantized_mistral.rs41
-rw-r--r--candle-transformers/src/models/quantized_mixformer.rs37
-rw-r--r--candle-transformers/src/models/quantized_stable_lm.rs25
-rw-r--r--candle-transformers/src/models/quantized_t5.rs27
-rw-r--r--candle-transformers/src/models/whisper/quantized_model.rs48
-rw-r--r--candle-transformers/src/quantized_nn.rs87
7 files changed, 100 insertions, 166 deletions
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index a4c7ddf7..b2b062a9 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -2,5 +2,6 @@ pub mod generation;
pub mod models;
pub mod object_detection;
pub mod pipelines;
+pub mod quantized_nn;
pub mod quantized_var_builder;
pub mod utils;
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 171e7440..00c80209 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
@@ -8,44 +7,6 @@ use std::sync::Arc;
pub use crate::models::mistral::Config;
#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- x.apply(&self.weight)
- }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear { weight })
-}
-
-#[derive(Debug)]
-struct RmsNorm {
- inner: candle_nn::RmsNorm,
- span: tracing::Span,
-}
-
-impl RmsNorm {
- fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let inner = candle_nn::RmsNorm::new(weight, eps);
- Ok(Self { inner, span })
- }
-}
-
-impl Module for RmsNorm {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-#[derive(Debug)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index f7eebb72..23eeb0ac 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -1,4 +1,4 @@
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Activation;
@@ -9,12 +9,12 @@ const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug)]
struct Embedding {
- wte: super::quantized_t5::Embedding,
+ wte: crate::quantized_nn::Embedding,
}
impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+ let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
Ok(Self { wte })
}
}
@@ -25,37 +25,6 @@ impl Module for Embedding {
}
}
-#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
- bias: Option<Tensor>,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let x = x.apply(&self.weight)?;
- match &self.bias {
- None => Ok(x),
- Some(bias) => x.broadcast_add(bias),
- }
- }
-}
-
-fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear {
- weight,
- bias: Some(bias),
- })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
- Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 86964237..304e91ee 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm};
@@ -9,28 +8,6 @@ pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;
#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- x.apply(&self.weight)
- }
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
- Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
-#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 398e82a7..1426df39 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -2,38 +2,13 @@
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
use serde::Deserialize;
use std::sync::Arc;
-#[derive(Debug)]
-pub struct Embedding {
- inner: candle_nn::Embedding,
- span: tracing::Span,
-}
-
-impl Embedding {
- pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
- let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
- let inner = candle_nn::Embedding::new(embeddings, d2);
- let span = tracing::span!(tracing::Level::TRACE, "embedding");
- Ok(Self { inner, span })
- }
-
- pub fn embeddings(&self) -> &Tensor {
- self.inner.embeddings()
- }
-}
-
-impl Module for Embedding {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(xs)
- }
-}
-
fn default_relative_attention_max_distance() -> usize {
128
}
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index 26ec6c94..f0aead49 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -1,39 +1,9 @@
use super::Config;
-use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
-#[derive(Debug)]
-struct Linear {
- weight: QMatMul,
- bias: Option<Tensor>,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let x = x.apply(&self.weight)?;
- match &self.bias {
- None => Ok(x),
- Some(bias) => x.broadcast_add(bias),
- }
- }
-}
-
-fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear {
- weight,
- bias: Some(bias),
- })
-}
-
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = QMatMul::new(in_dim, out_dim, vb)?;
- Ok(Linear { weight, bias: None })
-}
-
fn conv1d(
in_channels: usize,
out_channels: usize,
@@ -48,12 +18,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
-fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
- let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
- Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
-}
-
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
query: Linear,
@@ -178,10 +142,10 @@ impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
- let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+ let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
- let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+ let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
@@ -189,7 +153,7 @@ impl ResidualAttentionBlock {
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
- let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+ let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
@@ -281,7 +245,7 @@ impl AudioEncoder {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
- let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+ let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
@@ -343,7 +307,7 @@ impl TextDecoder {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
- let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
+ let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
new file mode 100644
index 00000000..1745327d
--- /dev/null
+++ b/candle-transformers/src/quantized_nn.rs
@@ -0,0 +1,87 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor};
+
+#[derive(Debug)]
+pub struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
+ let inner = candle_nn::Embedding::new(embeddings, d2);
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ pub fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ weight: QMatMul,
+ bias: Option<Tensor>,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let x = x.apply(&self.weight)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear {
+ weight,
+ bias: Some(bias),
+ })
+}
+
+pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+ Ok(candle_nn::LayerNorm::new(weight, bias, eps))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear { weight, bias: None })
+}
+
+#[derive(Debug)]
+pub struct RmsNorm {
+ inner: candle_nn::RmsNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let inner = candle_nn::RmsNorm::new(weight, eps);
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}