| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-17 10:07:13 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-17 10:07:13 +0100 |
| commit | d32e8199cd6c8381aa309528675d6d6a88c0f850 (patch) | |
| tree | a77b1e6910e9aec0a87813cf6a8f5120838470ce /candle-wasm-examples/llama2-c | |
| parent | d99cac3ec38a52bd81cc72059259729e7272e490 (diff) | |
| download | candle-d32e8199cd6c8381aa309528675d6d6a88c0f850.tar.gz candle-d32e8199cd6c8381aa309528675d6d6a88c0f850.tar.bz2 candle-d32e8199cd6c8381aa309528675d6d6a88c0f850.zip | |
Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the custom RmsNorm with the shared implementation (see the usage sketch below).
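
For readers skimming the diff, here is a minimal, self-contained usage sketch of the shared helper, assuming a recent `candle`/`candle_nn` in which `rms_norm`, `VarMap`, and the `Module` trait are available; the hidden size of 8, the `eps` of 1e-5, and the `"norm"` prefix are illustrative values chosen for the sketch, not taken from this model's config:

```rust
// Usage sketch (illustrative values): build an RMS-norm layer via the shared
// candle_nn helper and run it on a (batch, seq_len, hidden) tensor, mirroring
// how the llama2-c model code calls it in the diff below.
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{rms_norm, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // A VarMap-backed VarBuilder creates the "weight" variable on demand here;
    // the real model instead reads it from the checkpoint's VarBuilder.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Hidden size 8 and eps 1e-5 are placeholder values for this sketch.
    let norm = rms_norm(8, 1e-5, vb.pp("norm"))?;

    // Apply the norm over the last dimension of a (batch, seq_len, hidden) tensor.
    let x = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?;
    let y = norm.forward(&x)?;
    println!("output dims: {:?}", y.dims());
    Ok(())
}
```

The call shape matches what the diff below does in `Block::load` and `Llama::load`, where the size and epsilon come from `Config` (`cfg.dim`, `cfg.norm_eps`) and the weights come from the model's `VarBuilder`.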
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 44 |
1 file changed, 9 insertions(+), 35 deletions(-)
```diff
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3231cabf..d2b787ae 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.dim))
 }

-struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
-}
-
-impl RmsNorm {
-    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
 }

 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }

 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -267,9 +241,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
@@ -282,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: RmsNorm,
+    ln_f: LayerNorm,
     lm_head: Linear,
 }

 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
@@ -311,7 +285,7 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();
```
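
For reference, the normalization implemented by the removed `RmsNorm::forward` above, which the shared `rms_norm` helper replaces, can be written out explicitly. In the sketch below, `d` is the hidden size, `γ` is the learned `weight` tensor, and `ε` is `norm_eps`; this is read directly off the deleted code, which divides by the root mean square of the last dimension with no mean subtraction and no bias:

```latex
% RMS normalization as computed by the removed RmsNorm::forward:
% y_i = gamma_i * x_i / sqrt(mean_j(x_j^2) + eps), over the hidden dimension.
y_i = \gamma_i \, \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}},
\qquad i = 1, \dots, d
```

Skipping the mean subtraction and the bias is what distinguishes this from a standard layer norm, which is why the change can express it as a configurable `LayerNorm` built by `rms_norm`.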