summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/quantized
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-08-23 11:40:20 +0100
committer: GitHub <noreply@github.com> 2023-08-23 11:40:20 +0100
commit: 4ee1cf038ada55ec477dcd6496cf2aec1902775b (patch)
tree: d8085698f52851484da910045df3129aad27e708 /candle-examples/examples/quantized
parent: 0f4ff8a739facafd4b3bc9a003d4a581202b62f8 (diff)
download: candle-4ee1cf038ada55ec477dcd6496cf2aec1902775b.tar.gz
candle-4ee1cf038ada55ec477dcd6496cf2aec1902775b.tar.bz2
candle-4ee1cf038ada55ec477dcd6496cf2aec1902775b.zip
Get the rms epsilon from GGUF. (#565)
Diffstat (limited to 'candle-examples/examples/quantized')
-rw-r--r-- candle-examples/examples/quantized/main.rs 18
1 files changed, 10 insertions, 8 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index dfe81632..f91f3194 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -25,10 +25,10 @@ struct RmsNorm {
}
impl RmsNorm {
- fn new(scale: QTensor) -> Result<Self> {
+ fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
- let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+ let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span })
}
@@ -217,7 +217,7 @@ impl ModelWeights {
let (cos, sin) = precomput_freqs_cis(head_dim)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
+ let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
@@ -239,11 +239,11 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
- attention_norm: RmsNorm::new(attention_norm)?,
+ attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
- ffn_norm: RmsNorm::new(ffn_norm)?,
+ ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -284,12 +284,14 @@ impl ModelWeights {
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+ // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
let (cos, sin) = precomput_freqs_cis(rope_dim)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?;
+ let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
let output = ct.tensor(reader, "output.weight")?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
@@ -311,11 +313,11 @@ impl ModelWeights {
attention_wk: QMatMul::from_qtensor(attention_wk),
attention_wv: QMatMul::from_qtensor(attention_wv),
attention_wo: QMatMul::from_qtensor(attention_wo),
- attention_norm: RmsNorm::new(attention_norm)?,
+ attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
- ffn_norm: RmsNorm::new(ffn_norm)?,
+ ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,