From 13401df4d141bf568a2c2056411d62060707e79b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 18 Aug 2023 08:52:14 +0100
Subject: Add an abstract type for RmsNorm. (#499)

---
 candle-wasm-examples/llama2-c/src/model.rs | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

(limited to 'candle-wasm-examples/llama2-c')

diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index d2b787ae..2c867793 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -213,14 +213,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: LayerNorm,
+    rms_1: RmsNorm,
     attn: CausalSelfAttention,
-    rms_2: LayerNorm,
+    rms_2: RmsNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -256,12 +256,12 @@ impl Block {
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
-    ln_f: LayerNorm,
+    ln_f: RmsNorm,
     lm_head: Linear,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
+    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
         Self {
             wte,
             blocks,
-- 
cgit v1.2.3