summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-21 18:49:35 +0100
committerGitHub <noreply@github.com>2024-03-21 18:49:35 +0100
commitc0bdd9c7a613682ed1f2e7010374bb03621c4153 (patch)
tree60bb6c07cb0012be59a983dddd623b3122cc0ac3 /candle-transformers
parent9563a5fee42f8fef754c238e28ca79725813cea1 (diff)
downloadcandle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.tar.gz
candle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.tar.bz2
candle-c0bdd9c7a613682ed1f2e7010374bb03621c4153.zip
Use the fast RmsNorm in the quantized model. (#1904)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/quantized_llama.rs35
-rw-r--r--candle-transformers/src/models/quantized_mistral.rs1
-rw-r--r--candle-transformers/src/quantized_nn.rs20
3 files changed, 21 insertions, 35 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 5ce2de59..ee50c092 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -1,5 +1,6 @@
use std::collections::HashMap;
+use crate::quantized_nn::RmsNorm;
use candle::quantized::QTensor;
use candle::quantized::{ggml_file, gguf_file};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
@@ -7,26 +8,6 @@ use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
-#[derive(Debug, Clone)]
-struct RmsNorm {
- inner: candle_nn::LayerNorm,
- span: tracing::Span,
-}
-
-impl RmsNorm {
- fn new(scale: QTensor, eps: f32) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let scale = scale.dequantize(&scale.device())?;
- let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
- Ok(Self { inner, span })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
// QMatMul wrapper adding some tracing.
#[derive(Debug, Clone)]
struct QMatMul {
@@ -301,7 +282,7 @@ impl ModelWeights {
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
- let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
+ let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
@@ -330,9 +311,9 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
- attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
+ attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
mlp_or_moe,
- ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
+ ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -381,7 +362,7 @@ impl ModelWeights {
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
- let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f64()?;
let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
@@ -391,7 +372,7 @@ impl ModelWeights {
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
- let norm = RmsNorm::new(
+ let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
@@ -450,9 +431,9 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
- attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
+ attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
mlp_or_moe,
- ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
+ ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index f2cb3b27..77de7b75 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -327,6 +327,7 @@ impl Model {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
+ .contiguous()?
.apply(&self.norm)?
.apply(&self.lm_head)
}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index bb0a8641..9298b80e 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -1,5 +1,6 @@
use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
+use candle::quantized::QTensor;
use candle::{Module, Result, Tensor};
#[derive(Debug, Clone)]
@@ -35,10 +36,7 @@ pub struct Linear {
}
impl Linear {
- pub fn from_arc(
- weight: std::sync::Arc<candle::quantized::QTensor>,
- bias: Option<Tensor>,
- ) -> Result<Self> {
+ pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> {
let weight = QMatMul::from_weights(weight)?;
Ok(Self { weight, bias })
}
@@ -95,7 +93,8 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
#[derive(Debug, Clone)]
pub struct RmsNorm {
- inner: candle_nn::RmsNorm,
+ weight: Tensor,
+ eps: f64,
span: tracing::Span,
}
@@ -103,14 +102,19 @@ impl RmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
- let inner = candle_nn::RmsNorm::new(weight, eps);
- Ok(Self { inner, span })
+ Ok(Self { weight, eps, span })
+ }
+
+ pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let weight = weight.dequantize(&weight.device())?;
+ Ok(Self { weight, eps, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- self.inner.forward(x)
+ candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32)
}
}