author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-18 08:52:14 +0100
committer GitHub <noreply@github.com>  2023-08-18 08:52:14 +0100
commit    13401df4d141bf568a2c2056411d62060707e79b (patch)
tree      6c02f1465863fa5986cfdba4201cfc68c5bea533 /candle-wasm-examples/llama2-c
parent    a22b1bed7bc5c11730dce857a882c1d3642c51e9 (diff)
Add an abstract type for RmsNorm. (#499)
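The commit replaces the LayerNorm type with a dedicated RmsNorm abstraction in the llama2-c model, matching the rms_norm constructor that was already in use. A minimal usage sketch of the new type, assuming the candle-nn API imported in the diff below (rms_norm, RmsNorm, VarBuilder); the VarMap/device setup is illustrative, whereas the real model loads its weights from a checkpoint:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, Module, RmsNorm, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Backing store for the learnable scale; freshly initialized here.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // RmsNorm over a hidden size of 8 with eps = 1e-5. Unlike LayerNorm,
    // RMS norm only rescales by 1 / sqrt(mean(x^2) + eps); it subtracts
    // no mean and adds no bias.
    let ln_f: RmsNorm = rms_norm(8, 1e-5, vb.pp("ln_f"))?;

    let xs = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;
    let ys = ln_f.forward(&xs)?;
    println!("{ys}");
    Ok(())
}

Giving RMS normalization its own type keeps the struct fields honest: a field typed RmsNorm can no longer silently hold a mean-centering LayerNorm.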
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- candle-wasm-examples/llama2-c/src/model.rs | 12
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index d2b787ae..2c867793 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -213,14 +213,14 @@ impl Mlp {
}
struct Block {
- rms_1: LayerNorm,
+ rms_1: RmsNorm,
attn: CausalSelfAttention,
- rms_2: LayerNorm,
+ rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -256,12 +256,12 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
- ln_f: LayerNorm,
+ ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
- fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,