diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-09 06:22:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-09 06:22:22 +0100 |
commit | 392fe02fba96658bafc73100e80bf68d54e4e23f (patch) | |
tree | 3c3f9ef5e663a374011c1c90bec8e0e2b6bb30f8 | |
parent | 59ab6d7832600083a1519aa0511e9c7c832ae01c (diff) | |
download | candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.gz candle-392fe02fba96658bafc73100e80bf68d54e4e23f.tar.bz2 candle-392fe02fba96658bafc73100e80bf68d54e4e23f.zip |
Move the common quantized-nn code to a shared module. (#1063)
-rw-r--r-- | candle-transformers/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_mistral.rs | 41 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_mixformer.rs | 37 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_stable_lm.rs | 25 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs | 27 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/quantized_model.rs | 48 | ||||
-rw-r--r-- | candle-transformers/src/quantized_nn.rs | 87 |
7 files changed, 100 insertions, 166 deletions
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index a4c7ddf7..b2b062a9 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -2,5 +2,6 @@ pub mod generation; pub mod models; pub mod object_detection; pub mod pipelines; +pub mod quantized_nn; pub mod quantized_var_builder; pub mod utils; diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 171e7440..00c80209 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -1,5 +1,4 @@ -use crate::models::quantized_t5::Embedding; -use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::Activation; @@ -8,44 +7,6 @@ use std::sync::Arc; pub use crate::models::mistral::Config; #[derive(Debug)] -struct Linear { - weight: QMatMul, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - x.apply(&self.weight) - } -} - -fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { weight }) -} - -#[derive(Debug)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let inner = candle_nn::RmsNorm::new(weight, eps); - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -#[derive(Debug)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index f7eebb72..23eeb0ac 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -1,4 +1,4 @@ -use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::{layer_norm, linear, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::Activation; @@ -9,12 +9,12 @@ const MAX_SEQ_LEN: usize = 4096; #[derive(Debug)] struct Embedding { - wte: super::quantized_t5::Embedding, + wte: crate::quantized_nn::Embedding, } impl Embedding { fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { - let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?; + let wte = crate::quantized_nn::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?; Ok(Self { wte }) } } @@ -25,37 +25,6 @@ impl Module for Embedding { } } -#[derive(Debug)] -struct Linear { - weight: QMatMul, - bias: Option<Tensor>, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - let x = x.apply(&self.weight)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } - } -} - -fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { - weight, - bias: Some(bias), - }) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let bias = vb.get(size, "bias")?.dequantize(vb.device())?; - Ok(candle_nn::LayerNorm::new(weight, bias, eps)) -} - fn get_mask(size: usize, device: &Device) -> Result<Tensor> { let mask: Vec<_> = (0..size) .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs index 86964237..304e91ee 100644 --- a/candle-transformers/src/models/quantized_stable_lm.rs +++ b/candle-transformers/src/models/quantized_stable_lm.rs @@ -1,5 +1,4 @@ -use crate::models::quantized_t5::Embedding; -use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm}; @@ -9,28 +8,6 @@ pub use crate::models::stable_lm::Config; use crate::models::stable_lm::RotaryEmbedding; #[derive(Debug)] -struct Linear { - weight: QMatMul, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - x.apply(&self.weight) - } -} - -fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { weight }) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let bias = vb.get(size, "bias")?.dequantize(vb.device())?; - Ok(candle_nn::LayerNorm::new(weight, bias, eps)) -} - -#[derive(Debug)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Linear, diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 398e82a7..1426df39 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -2,38 +2,13 @@ // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py use crate::models::with_tracing::QMatMul; +use crate::quantized_nn::Embedding; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::Activation; use serde::Deserialize; use std::sync::Arc; -#[derive(Debug)] -pub struct Embedding { - inner: candle_nn::Embedding, - span: tracing::Span, -} - -impl Embedding { - pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { - let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?; - let inner = candle_nn::Embedding::new(embeddings, d2); - let span = tracing::span!(tracing::Level::TRACE, "embedding"); - Ok(Self { inner, span }) - } - - pub fn embeddings(&self) -> &Tensor { - self.inner.embeddings() - } -} - -impl Module for Embedding { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - fn default_relative_attention_max_distance() -> usize { 128 } diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs index 26ec6c94..f0aead49 100644 --- a/candle-transformers/src/models/whisper/quantized_model.rs +++ b/candle-transformers/src/models/whisper/quantized_model.rs @@ -1,39 +1,9 @@ use super::Config; -use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul}; +use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module}; -#[derive(Debug)] -struct Linear { - weight: QMatMul, - bias: Option<Tensor>, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - let x = x.apply(&self.weight)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } - } -} - -fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { - weight, - bias: Some(bias), - }) -} - -fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { - let weight = QMatMul::new(in_dim, out_dim, vb)?; - Ok(Linear { weight, bias: None }) -} - fn conv1d( in_channels: usize, out_channels: usize, @@ -48,12 +18,6 @@ fn conv1d( Ok(Conv1d::new(weight, Some(bias), config)) } -fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { - let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let bias = vb.get(size, "bias")?.dequantize(vb.device())?; - Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5)) -} - // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 struct MultiHeadAttention { query: Linear, @@ -178,10 +142,10 @@ impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; - let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; - let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?; Some((cross_attn, cross_attn_ln)) } else { None @@ -189,7 +153,7 @@ impl ResidualAttentionBlock { let n_mlp = n_state * 4; let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; - let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?; Ok(Self { attn, attn_ln, @@ -281,7 +245,7 @@ impl AudioEncoder { ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}"))) }) .collect::<Result<Vec<_>>>()?; - let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?; Ok(Self { conv1, conv2, @@ -343,7 +307,7 @@ impl TextDecoder { ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}"))) }) .collect::<Result<Vec<_>>>()?; - let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?; let mask: Vec<_> = (0..n_ctx) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .collect(); diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs new file mode 100644 index 00000000..1745327d --- /dev/null +++ b/candle-transformers/src/quantized_nn.rs @@ -0,0 +1,87 @@ +use crate::models::with_tracing::QMatMul; +use crate::quantized_var_builder::VarBuilder; +use candle::{Module, Result, Tensor}; + +#[derive(Debug)] +pub struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?; + let inner = candle_nn::Embedding::new(embeddings, d2); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + pub fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug)] +pub struct Linear { + weight: QMatMul, + bias: Option<Tensor>, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let x = x.apply(&self.weight)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { + weight, + bias: Some(bias), + }) +} + +pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let bias = vb.get(size, "bias")?.dequantize(vb.device())?; + Ok(candle_nn::LayerNorm::new(weight, bias, eps)) +} + +pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { weight, bias: None }) +} + +#[derive(Debug)] +pub struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let inner = candle_nn::RmsNorm::new(weight, eps); + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} |