summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-17 10:07:13 +0100
committerGitHub <noreply@github.com>2023-08-17 10:07:13 +0100
commitd32e8199cd6c8381aa309528675d6d6a88c0f850 (patch)
treea77b1e6910e9aec0a87813cf6a8f5120838470ce /candle-wasm-examples/llama2-c
parentd99cac3ec38a52bd81cc72059259729e7272e490 (diff)
downloadcandle-d32e8199cd6c8381aa309528675d6d6a88c0f850.tar.gz
candle-d32e8199cd6c8381aa309528675d6d6a88c0f850.tar.bz2
candle-d32e8199cd6c8381aa309528675d6d6a88c0f850.zip
Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable. * Add the rms-norm variant. * Replace the RmsNorm with the shared bits.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs44
1 files changed, 9 insertions, 35 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3231cabf..d2b787ae 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
Ok(Embedding::new(embeddings, cfg.dim))
}
-struct RmsNorm {
- scale: Tensor,
- eps: f64,
-}
-
-impl RmsNorm {
- fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let scale = vb.get(size, "weight")?;
- Ok(Self { scale, eps })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- Ok(x)
- }
-}
-
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
}
struct Block {
- rms_1: RmsNorm,
+ rms_1: LayerNorm,
attn: CausalSelfAttention,
- rms_2: RmsNorm,
+ rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -267,9 +241,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
- let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+ let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
- RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+ rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@@ -282,12 +256,12 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
- ln_f: RmsNorm,
+ ln_f: LayerNorm,
lm_head: Linear,
}
impl Llama {
- fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
@@ -311,7 +285,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
- let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+ let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();