diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-18 08:52:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-18 08:52:14 +0100 |
commit | 13401df4d141bf568a2c2056411d62060707e79b (patch) | |
tree | 6c02f1465863fa5986cfdba4201cfc68c5bea533 /candle-wasm-examples/llama2-c | |
parent | a22b1bed7bc5c11730dce857a882c1d3642c51e9 (diff) | |
download | candle-13401df4d141bf568a2c2056411d62060707e79b.tar.gz candle-13401df4d141bf568a2c2056411d62060707e79b.tar.bz2 candle-13401df4d141bf568a2c2056411d62060707e79b.zip |
Add an abstract type for RmsNorm. (#499)
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index d2b787ae..2c867793 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder}; +use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -213,14 +213,14 @@ impl Mlp { } struct Block { - rms_1: LayerNorm, + rms_1: RmsNorm, attn: CausalSelfAttention, - rms_2: LayerNorm, + rms_2: RmsNorm, mlp: Mlp, } impl Block { - fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { Self { rms_1, attn, @@ -256,12 +256,12 @@ impl Block { pub struct Llama { wte: Embedding, blocks: Vec<Block>, - ln_f: LayerNorm, + ln_f: RmsNorm, lm_head: Linear, } impl Llama { - fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self { + fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self { Self { wte, blocks, |