diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-27 14:08:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-27 14:08:29 +0100 |
commit | 916619f70bfae089597ce421e19a3b2e85c2d27b (patch) | |
tree | d38e7546128f877ac6fa5ffefd99adfd4130de01 /candle-wasm-examples/llama2-c | |
parent | 9b1158b3158dae2eafb91e9da126f66bf9e111d6 (diff) | |
download | candle-916619f70bfae089597ce421e19a3b2e85c2d27b.tar.gz candle-916619f70bfae089597ce421e19a3b2e85c2d27b.tar.bz2 candle-916619f70bfae089597ce421e19a3b2e85c2d27b.zip |
Minor cleanup (#1194)
* Add some missing backtraces.
* Small cleanup.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 20 |
1 file changed, 3 insertions(+), 17 deletions(-)
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3fedb1d3..7471938a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -57,20 +57,6 @@ impl Cache {
     }
 }
 
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.dim))
-}
-
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -198,7 +184,7 @@ impl Mlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }
@@ -283,7 +269,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)