author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-23 11:40:20 +0100
committer | GitHub <noreply@github.com> | 2023-08-23 11:40:20 +0100
commit    | 4ee1cf038ada55ec477dcd6496cf2aec1902775b
tree      | d8085698f52851484da910045df3129aad27e708 /candle-examples/examples/quantized
parent    | 0f4ff8a739facafd4b3bc9a003d4a581202b62f8
Get the rms epsilon from GGUF. (#565)
Diffstat (limited to 'candle-examples/examples/quantized')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 18
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index dfe81632..f91f3194 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -25,10 +25,10 @@ struct RmsNorm {
 }
 
 impl RmsNorm {
-    fn new(scale: QTensor) -> Result<Self> {
+    fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+        let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
         Ok(Self { inner, span })
     }
@@ -217,7 +217,7 @@ impl ModelWeights {
         let (cos, sin) = precomput_freqs_cis(head_dim)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-        let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
+        let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
@@ -239,11 +239,11 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk),
                 attention_wv: QMatMul::from_qtensor(attention_wv),
                 attention_wo: QMatMul::from_qtensor(attention_wo),
-                attention_norm: RmsNorm::new(attention_norm)?,
+                attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
                 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
                 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
                 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
-                ffn_norm: RmsNorm::new(ffn_norm)?,
+                ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -284,12 +284,14 @@ impl ModelWeights {
         let block_count = md_get("llama.block_count")?.to_u32()? as usize;
         let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
         let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
         let (cos, sin) = precomput_freqs_cis(rope_dim)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?)?;
+        let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
         let output = ct.tensor(reader, "output.weight")?;
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
@@ -311,11 +313,11 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk),
                 attention_wv: QMatMul::from_qtensor(attention_wv),
                 attention_wo: QMatMul::from_qtensor(attention_wo),
-                attention_norm: RmsNorm::new(attention_norm)?,
+                attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
                 feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
                 feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
                 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
-                ffn_norm: RmsNorm::new(ffn_norm)?,
+                ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
                 n_head: head_count,
                 n_kv_head: head_count_kv,
                 head_dim: embedding_length / head_count,
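For context, RMS normalization divides each activation vector by the root mean square of its elements, and the epsilon only keeps that denominator away from zero; the value this commit starts reading from the GGUF metadata key llama.attention.layer_norm_rms_epsilon replaces the previously hard-coded 1e-5. Below is a minimal, self-contained sketch of the computation on plain f32 slices; it is illustrative only and does not use candle's API from the diff:

    // Illustrative RMS norm: y[i] = x[i] / sqrt(mean(x^2) + eps) * scale[i].
    // `eps` plays the role of the `rms_norm_eps` value read from GGUF above.
    fn rms_norm(x: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        x.iter().zip(scale).map(|(v, s)| v * inv_rms * s).collect()
    }

    fn main() {
        let x = [1.0_f32, 2.0, 3.0, 4.0];
        let scale = [1.0_f32; 4];
        // For well-scaled activations, eps = 1e-5 vs 1e-6 changes the output
        // only marginally, which is why the hard-coded default mostly worked.
        println!("{:?}", rms_norm(&x, &scale, 1e-5));
    }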