diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-21 18:49:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 18:49:35 +0100 |
commit | c0bdd9c7a613682ed1f2e7010374bb03621c4153 (patch) | |
tree | 60bb6c07cb0012be59a983dddd623b3122cc0ac3 /candle-transformers | |
parent | 9563a5fee42f8fef754c238e28ca79725813cea1 (diff) | |
download | candle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.tar.gz candle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.tar.bz2 candle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.zip |
Use the fast RmsNorm in the quantized model. (#1904)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 35 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_mistral.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/quantized_nn.rs | 20 |
3 files changed, 21 insertions, 35 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 5ce2de59..ee50c092 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use crate::quantized_nn::RmsNorm; use candle::quantized::QTensor; use candle::quantized::{ggml_file, gguf_file}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; @@ -7,26 +8,6 @@ use candle_nn::{Embedding, Module}; pub const MAX_SEQ_LEN: usize = 4096; -#[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::LayerNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(scale: QTensor, eps: f32) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&scale.device())?; - let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - // QMatMul wrapper adding some tracing. #[derive(Debug, Clone)] struct QMatMul { @@ -301,7 +282,7 @@ impl ModelWeights { let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; - let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; let output = ct.remove("output.weight")?; let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); for layer_idx in 0..ct.hparams.n_layer { @@ -330,9 +311,9 @@ impl ModelWeights { attention_wk: QMatMul::from_qtensor(attention_wk)?, attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, - attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?, mlp_or_moe, - ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?, n_head: ct.hparams.n_head as usize, n_kv_head: ct.hparams.n_head as usize / gqa, head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, @@ -381,7 +362,7 @@ impl ModelWeights { let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f64()?; let rope_freq_base = md_get("llama.rope.freq_base") .and_then(|m| m.to_f32()) @@ -391,7 +372,7 @@ impl ModelWeights { let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let norm = RmsNorm::new( + let norm = RmsNorm::from_qtensor( ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps, )?; @@ -450,9 +431,9 @@ impl ModelWeights { attention_wk: QMatMul::from_qtensor(attention_wk)?, attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, - attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?, mlp_or_moe, - ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?, n_head: head_count, n_kv_head: head_count_kv, head_dim: embedding_length / head_count, diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index f2cb3b27..77de7b75 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -327,6 +327,7 @@ impl Model { xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? } xs.narrow(1, seq_len - 1, 1)? + .contiguous()? .apply(&self.norm)? .apply(&self.lm_head) } diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs index bb0a8641..9298b80e 100644 --- a/candle-transformers/src/quantized_nn.rs +++ b/candle-transformers/src/quantized_nn.rs @@ -1,5 +1,6 @@ use crate::models::with_tracing::QMatMul; use crate::quantized_var_builder::VarBuilder; +use candle::quantized::QTensor; use candle::{Module, Result, Tensor}; #[derive(Debug, Clone)] @@ -35,10 +36,7 @@ pub struct Linear { } impl Linear { - pub fn from_arc( - weight: std::sync::Arc<candle::quantized::QTensor>, - bias: Option<Tensor>, - ) -> Result<Self> { + pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> { let weight = QMatMul::from_weights(weight)?; Ok(Self { weight, bias }) } @@ -95,7 +93,8 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L #[derive(Debug, Clone)] pub struct RmsNorm { - inner: candle_nn::RmsNorm, + weight: Tensor, + eps: f64, span: tracing::Span, } @@ -103,14 +102,19 @@ impl RmsNorm { pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); let weight = vb.get(size, "weight")?.dequantize(vb.device())?; - let inner = candle_nn::RmsNorm::new(weight, eps); - Ok(Self { inner, span }) + Ok(Self { weight, eps, span }) + } + + pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let weight = weight.dequantize(&weight.device())?; + Ok(Self { weight, eps, span }) } } impl Module for RmsNorm { fn forward(&self, x: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); - self.inner.forward(x) + candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32) } } |