From d32e8199cd6c8381aa309528675d6d6a88c0f850 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 17 Aug 2023 10:07:13 +0100
Subject: Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the RmsNorm with the shared bits.
---
 candle-wasm-examples/llama2-c/src/model.rs | 44 ++++++------------------------
 1 file changed, 9 insertions(+), 35 deletions(-)

diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3231cabf..d2b787ae 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.dim))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -267,9 +241,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -282,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
@@ -311,7 +285,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();
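For reference, here is a minimal, self-contained sketch of exercising the shared candle_nn::rms_norm helper that this patch switches to. The call shape rms_norm(size, eps, vb) mirrors its use in Block::load and Llama::load above; the VarMap-backed VarBuilder, the hidden size of 8, and the 1e-5 epsilon are illustrative assumptions, not values taken from this patch.

// Sketch only: a standalone use of candle_nn::rms_norm with an in-memory VarMap
// instead of weights loaded from a checkpoint.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Back the builder with a VarMap so the norm can create its "weight" variable.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // Same call shape as in the patch: rms_norm(size, eps, vb).
    let norm = rms_norm(8, 1e-5, vb.pp("norm"))?;
    // Normalize a (batch, seq, hidden) tensor over its last dimension.
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 8), &dev)?;
    let ys = norm.forward(&xs)?;
    println!("{:?}", ys.shape());
    Ok(())
}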