summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs20
1 files changed, 3 insertions, 17 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3fedb1d3..7471938a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -57,20 +57,6 @@ impl Cache {
}
}
-fn silu(xs: &Tensor) -> Result<Tensor> {
- xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
- Ok(Embedding::new(embeddings, cfg.dim))
-}
-
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -198,7 +184,7 @@ impl Mlp {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
@@ -283,7 +269,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)