diff options
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 20 |
1 file changed, 3 insertions(+), 17 deletions(-)
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3fedb1d3..7471938a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -57,20 +57,6 @@ impl Cache {
     }
 }
 
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.dim))
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -198,7 +184,7 @@ impl Mlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }
 
@@ -283,7 +269,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)