author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-17 10:07:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-17 10:07:13 +0100 |
commit | d32e8199cd6c8381aa309528675d6d6a88c0f850 | |
tree | a77b1e6910e9aec0a87813cf6a8f5120838470ce | /candle-examples/examples/llama2-c |
parent | d99cac3ec38a52bd81cc72059259729e7272e490 | |
Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the RmsNorm with the shared bits (a sketch of the computation follows below).
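For context, here is a minimal sketch (not part of the commit) of the computation that the removed hand-rolled RmsNorm performed and that the shared candle_nn helper now provides: divide by the root mean square over the hidden dimension, then apply a learned per-channel scale, with no mean subtraction and no bias. The function name is illustrative.

```rust
use candle::{Result, Tensor, D};

/// Illustrative RMS norm: y = scale * x / sqrt(mean(x^2) + eps).
/// Unlike a full layer norm, the mean is not subtracted and no bias is added.
fn rms_norm_manual(x: &Tensor, scale: &Tensor, eps: f64) -> Result<Tensor> {
    let hidden_size = x.dim(D::Minus1)? as f64;
    // Mean of squares over the last (hidden) dimension, keeping it for broadcasting.
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size)?;
    // Normalize by the root mean square, then apply the learned per-channel scale.
    let rms = (norm_x + eps)?.sqrt()?;
    let x_normed = x.broadcast_div(&rms)?;
    x_normed.broadcast_mul(scale)
}
```

The shared `rms_norm` helper wires this same computation to a weight loaded through a `VarBuilder`, which is what the diff below switches the model to.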
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/model.rs | 42
1 file changed, 8 insertions, 34 deletions
```diff
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 77900d27..75269665 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, Embedding, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }

-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -262,14 +236,14 @@ impl Mlp {
 }

 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }

 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -290,9 +264,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -305,7 +279,7 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
     pub config: Config,
 }
@@ -325,7 +299,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();
```
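Below is a hedged sketch of the caller side after this change. It assumes `VarBuilder::zeros` is available for dummy weights and that `forward` is reachable via the `Module` trait (it may also be an inherent method on `LayerNorm`); the `rms_norm` signature, the `LayerNorm` return type, and the `vb.pp(...)` path are taken from the diff above, everything else is illustrative.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, LayerNorm, Module, VarBuilder};

fn demo() -> Result<Tensor> {
    let device = Device::Cpu;
    // Dummy VarBuilder so the example is self-contained; a real model loads
    // weights from a checkpoint instead (assumption: zeros() is available here).
    let vb = VarBuilder::zeros(DType::F32, &device);

    // Same call shape as Block::load in the diff: at this commit the helper
    // returns a LayerNorm configured as an RMS norm (no mean subtraction, no bias).
    let dim = 64;
    let eps = 1e-5;
    let norm: LayerNorm = rms_norm(dim, eps, vb.pp("input_layernorm"))?;

    // Call sites such as `self.rms_1.forward(&x)?` are untouched by the commit,
    // since the returned LayerNorm exposes the same forward signature.
    let xs = Tensor::randn(0f32, 1f32, (1, 4, dim), &device)?;
    norm.forward(&xs)
}
```

Because the helper returns a `LayerNorm` with the same `forward` signature as the deleted `RmsNorm`, the forward call sites in `Block` and `Llama` did not need to change; only the struct fields and the load paths did.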