author     Laurent Mazare <laurent.mazare@gmail.com>  2023-10-27 14:08:29 +0100
committer  GitHub <noreply@github.com>  2023-10-27 14:08:29 +0100
commit     916619f70bfae089597ce421e19a3b2e85c2d27b (patch)
tree       d38e7546128f877ac6fa5ffefd99adfd4130de01 /candle-wasm-examples/llama2-c
parent     9b1158b3158dae2eafb91e9da126f66bf9e111d6 (diff)
Minor cleanup (#1194)
* Add some missing backtraces.
* Small cleanup.
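The cleanup in this file removes three hand-rolled helpers (silu, linear, embedding) in favor of the equivalents that candle_nn already exports. As a minimal sketch of those built-ins, assuming candle and candle_nn as dependencies: the sizes (vocab_size = 32000, dim = 288) and the zero-initialized VarBuilder below are illustrative placeholders, not values from this commit. One nuance worth flagging: candle_nn::linear also loads a "bias" tensor, whereas the deleted local helper constructed Linear with no bias (candle_nn::linear_no_bias matches that behavior exactly).

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{embedding, linear, Module, VarBuilder};

    fn demo() -> Result<()> {
        let dev = Device::Cpu;

        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); candle_nn::ops::silu
        // computes the same value as the deleted local silu helper.
        let xs = Tensor::new(&[-1f32, 0.0, 1.0], &dev)?;
        let _ys = candle_nn::ops::silu(&xs)?;

        // embedding(vocab_size, dim, vb) reads a (vocab_size, dim) "weight" tensor;
        // linear(in_dim, out_dim, vb) reads a (out_dim, in_dim) "weight" plus a "bias".
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let wte = embedding(32000, 288, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(288, 32000, vb.pp("lm_head"))?;

        let token = Tensor::new(&[42u32], &dev)?;
        let h = wte.forward(&token)?; // shape (1, 288)
        let logits = lm_head.forward(&h)?; // shape (1, 32000)
        println!("{:?}", logits.shape());
        Ok(())
    }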
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--  candle-wasm-examples/llama2-c/src/model.rs  20
1 file changed, 3 insertions, 17 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3fedb1d3..7471938a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -57,20 +57,6 @@ impl Cache {
}
}
-fn silu(xs: &Tensor) -> Result<Tensor> {
- xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
- Ok(Embedding::new(embeddings, cfg.dim))
-}
-
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -198,7 +184,7 @@ impl Mlp {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
@@ -283,7 +269,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
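The unchanged lines in this last hunk show where the commit lands: every building block (embedding, linear, rms_norm) now comes from candle_nn and pulls its weights through VarBuilder prefixes. For completeness, a minimal sketch of the rms_norm builder seen above, again with placeholder values (dim = 288 and eps = 1e-5 are illustrative, not this example's config):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{rms_norm, Module, VarBuilder};

    fn demo_norm() -> Result<()> {
        let dev = Device::Cpu;
        // rms_norm(size, eps, vb) loads a (size,) "weight" scale through the VarBuilder.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let norm = rms_norm(288, 1e-5, vb.pp("model.norm"))?;
        // RmsNorm divides by the root mean square over the last dimension,
        // then applies the learned per-channel scale.
        let x = Tensor::ones((1, 288), DType::F32, &dev)?;
        let y = norm.forward(&x)?;
        println!("{:?}", y.shape()); // (1, 288)
        Ok(())
    }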