Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs  353
1 file changed, 353 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
new file mode 100644
index 00000000..9b0351d6
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -0,0 +1,353 @@
+use crate::model::{Cache, Config, Llama};
+use byteorder::{LittleEndian, ReadBytesExt};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
+use candle_nn::VarBuilder;
+use rand::{distributions::Distribution, SeedableRng};
+use serde::{Deserialize, Serialize};
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this uses the `log` function imported above to write to the
+    // browser console.
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
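+
+// Illustrative usage; `md` stands in for a hypothetical `ModelData` value:
+//
+//     console_log!("loaded {} bytes of model data", md.model.len());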
+
+// Communication to the worker happens through bincode; the model weights and
+// configs are fetched on the main thread and transferred via the following
+// structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+ pub tokenizer: Vec<u8>,
+ pub model: Vec<u8>,
+}
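+
+// An illustrative round trip, assuming the bincode 1.x API; in practice the
+// yew_agent bridge performs this (de)serialization when the app sends a
+// `WorkerInput::ModelData(md)`. `tokenizer_bytes` and `model_bytes` are
+// placeholders for the fetched blobs:
+//
+//     let md = ModelData { tokenizer: tokenizer_bytes, model: model_bytes };
+//     let encoded: Vec<u8> = bincode::serialize(&md).unwrap();
+//     let decoded: ModelData = bincode::deserialize(&encoded).unwrap();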
+
+/// Reads a single little-endian i32 from `r`.
+fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
+ let mut buf = [0u8; 4];
+ r.read_exact(&mut buf)?;
+ Ok(i32::from_le_bytes(buf))
+}
+
+/// Reads `shape.elem_count()` little-endian f32 values from `r` and builds a
+/// tensor of the given shape on `dev`.
+fn read_tensor<R: std::io::Read, S: Into<Shape>>(
+ r: &mut R,
+ shape: S,
+ dev: &Device,
+) -> Result<Tensor> {
+ let shape = shape.into();
+ let mut data_t = vec![0f32; shape.elem_count()];
+ r.read_f32_into::<LittleEndian>(&mut data_t)?;
+ let tensor = Tensor::from_vec(data_t, shape, dev)?;
+ Ok(tensor)
+}
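+
+// e.g. `read_tensor(r, (2, 3), dev)` consumes six f32 values (24 bytes) and
+// yields a (2, 3) tensor in row-major order.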
+
+struct Tokenizer {
+ tokens: Vec<String>,
+}
+
+impl Tokenizer {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
+ let mut tokens = Vec::with_capacity(c.vocab_size);
+ for _token_index in 0..c.vocab_size {
+ let token_len = read_i32(r)?;
+ let mut token = vec![0u8; token_len as usize];
+ r.read_exact(&mut token)?;
+ tokens.push(String::from_utf8_lossy(&token).into_owned())
+ }
+ Ok(Self { tokens })
+ }
+}
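+
+// The tokenizer blob parsed above is a flat sequence of `vocab_size` entries,
+// each a little-endian i32 byte length followed by that many bytes of token
+// text:
+//
+//     [len_0: i32][bytes_0][len_1: i32][bytes_1]...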
+
+struct Model {
+ cache: Cache,
+ config: Config,
+ llama: Llama,
+ tokenizer: Tokenizer,
+}
+
+pub struct LogitsProcessor {
+ rng: rand::rngs::StdRng,
+ temperature: Option<f64>,
+}
+
+impl LogitsProcessor {
+ pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ Self {
+ rng: rand::rngs::StdRng::seed_from_u64(seed),
+ temperature,
+ }
+ }
+
+ pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+ let logits = logits.to_dtype(DType::F32)?;
+ let next_token = if let Some(temperature) = self.temperature {
+ let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs: Vec<f32> = prs.to_vec1()?;
+ let distr =
+ rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
+ distr.sample(&mut self.rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ Ok(next_token)
+ }
+}
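+
+// A minimal usage sketch (seed and temperature values are arbitrary): `None`
+// gives greedy argmax decoding, while `Some(t)` scales the logits by 1/t and
+// samples from the resulting softmax distribution.
+//
+//     let mut logits_processor = LogitsProcessor::new(42, Some(0.8));
+//     let next_token = logits_processor.sample(&logits)?;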
+
+impl Model {
+ fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
+ let dev = Device::Cpu;
+ let mut logits_processor = LogitsProcessor::new(299792458, None);
+ let mut index_pos = 0;
+ let mut tokens = vec![1u32];
+
+ for index in 0..self.config.seq_len - 10 {
+ let context_size = if self.cache.use_kv_cache && index > 0 {
+ 1
+ } else {
+ tokens.len()
+ };
+ let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+ let input = Tensor::new(ctxt, &dev)?.unsqueeze(0)?;
+ let logits = self.llama.forward(&input, index_pos)?;
+ let logits = logits.squeeze(0)?;
+ index_pos += ctxt.len();
+
+ let next_token = logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ let token = self.tokenizer.tokens[next_token as usize].clone();
+ link.respond(id, Ok(WorkerOutput::Generated(token)));
+ }
+ Ok(())
+ }
+}
+
+impl Config {
+ fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
+ let dim = read_i32(r)? as usize;
+ let hidden_dim = read_i32(r)? as usize;
+ let n_layers = read_i32(r)? as usize;
+ let n_heads = read_i32(r)? as usize;
+ let n_kv_heads = read_i32(r)? as usize;
+ let vocab_size = read_i32(r)? as usize;
+ let seq_len = read_i32(r)? as usize;
+ Ok(Self {
+ dim,
+ hidden_dim,
+ n_layers,
+ n_heads,
+ n_kv_heads,
+ vocab_size,
+ seq_len,
+ norm_eps: 1e-5,
+ })
+ }
+
+ pub fn head_size(&self) -> usize {
+ self.dim / self.n_heads
+ }
+}
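+
+// The header parsed above mirrors llama2.c's export format: seven 32-bit
+// little-endian integers written back to back, in this order:
+//
+//     dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len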
+
+struct TransformerWeights {
+ // token embedding table
+ token_embedding_table: Tensor, // (vocab_size, dim)
+ // weights for rmsnorms
+ rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
+ rms_ffn_weight: Tensor, // (layer, dim)
+ // weights for matmuls
+ wq: Tensor, // (layer, dim, dim)
+ wk: Tensor, // (layer, dim, dim)
+ wv: Tensor, // (layer, dim, dim)
+ wo: Tensor, // (layer, dim, dim)
+ // weights for ffn
+ w1: Tensor, // (layer, hidden_dim, dim)
+ w2: Tensor, // (layer, dim, hidden_dim)
+ w3: Tensor, // (layer, hidden_dim, dim)
+ // final rmsnorm
+ rms_final_weight: Tensor, // (dim,)
+    // freq_cis for RoPE rotary positional embeddings
+ freq_cis_real: Tensor, // (seq_len, head_size/2)
+ freq_cis_imag: Tensor, // (seq_len, head_size/2)
+}
+
+impl TransformerWeights {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
+ let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
+ let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
+ let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let rms_final_weight = read_tensor(r, c.dim, dev)?;
+ let head_size = c.head_size();
+ let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ Ok(Self {
+ token_embedding_table,
+ rms_att_weight,
+ wq,
+ wk,
+ wv,
+ wo,
+ rms_ffn_weight,
+ w1,
+ w2,
+ w3,
+ rms_final_weight,
+ freq_cis_real,
+ freq_cis_imag,
+ })
+ }
+
+ fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ let mut ws = std::collections::HashMap::new();
+ let mut insert = |name: &str, t: Tensor| {
+ ws.insert(name.to_string(), t);
+ };
+ insert("rot.freq_cis_real", self.freq_cis_real.clone());
+ insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
+ insert(
+ "model.embed_tokens.weight",
+ self.token_embedding_table.clone(),
+ );
+ insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("model.norm.weight", self.rms_final_weight.clone());
+ for layer in 0..cfg.n_layers {
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.q_proj.weight"),
+ self.wq.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.k_proj.weight"),
+ self.wk.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.v_proj.weight"),
+ self.wv.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.o_proj.weight"),
+ self.wo.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.gate_proj.weight"),
+ self.w1.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.down_proj.weight"),
+ self.w2.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.up_proj.weight"),
+ self.w3.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.input_layernorm.weight"),
+ self.rms_att_weight.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.post_attention_layernorm.weight"),
+ self.rms_ffn_weight.i(layer)?,
+ );
+ }
+ let vb = VarBuilder::from_tensors(ws, DType::F32, device);
+ Ok(vb)
+ }
+}
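+
+// The names inserted above follow the Hugging Face Llama checkpoint
+// convention (e.g. layer 0's `wq` becomes
+// `model.layers.0.self_attn.q_proj.weight`), so `Llama::load` can look the
+// llama2.c weights up through the same `VarBuilder` paths it uses for
+// regular Llama checkpoints.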
+
+impl Model {
+ fn load(md: ModelData) -> Result<Self> {
+ let dev = Device::Cpu;
+ let mut model = std::io::Cursor::new(md.model);
+ let config = Config::from_reader(&mut model)?;
+ let weights = TransformerWeights::from_reader(&mut model, &config, &dev)?;
+ let vb = weights.var_builder(&config, &dev)?;
+ let cache = Cache::new(true, &config, vb.pp("rot"))?;
+ let llama = Llama::load(vb, &cache, &config)?;
+ let mut tokenizer = std::io::Cursor::new(md.tokenizer);
+ let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+ Ok(Self {
+ cache,
+ config,
+ llama,
+ tokenizer,
+ })
+ }
+}
+
+pub struct Worker {
+ link: WorkerLink<Self>,
+ model: Option<Model>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+ Run,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerOutput {
+ Generated(String),
+ GenerationDone(std::result::Result<(), String>),
+ WeightsLoaded,
+}
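+
+// Protocol summary: the app sends `ModelData` once and receives
+// `WeightsLoaded` (or an error string); each subsequent `Run` streams one
+// `Generated` message per token and ends with a single `GenerationDone`.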
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+ type Output = std::result::Result<WorkerOutput, String>;
+ type Reach = Public<Self>;
+
+ fn create(link: WorkerLink<Self>) -> Self {
+ Self { link, model: None }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let output = match msg {
+ WorkerInput::ModelData(md) => match Model::load(md) {
+ Ok(model) => {
+ self.model = Some(model);
+ Ok(WorkerOutput::WeightsLoaded)
+ }
+                Err(err) => Err(format!("model creation error: {err:?}")),
+ },
+ WorkerInput::Run => match &self.model {
+ None => Err("model has not been set yet".to_string()),
+ Some(model) => {
+ let result = model.run(&self.link, id).map_err(|e| e.to_string());
+ Ok(WorkerOutput::GenerationDone(result))
+ }
+ },
+ };
+ self.link.respond(id, output);
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}
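+
+// A couple of sanity-check sketches for the binary readers above; the
+// constants are arbitrary.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Encode a fake llama2.c header and check that `Config::from_reader`
+    // decodes the seven fields in order.
+    #[test]
+    fn config_header_round_trip() -> Result<()> {
+        let fields: [i32; 7] = [8, 32, 2, 4, 4, 16, 64];
+        let mut bytes = Vec::new();
+        for f in fields {
+            bytes.extend_from_slice(&f.to_le_bytes());
+        }
+        let config = Config::from_reader(&mut std::io::Cursor::new(bytes))?;
+        assert_eq!(config.dim, 8);
+        assert_eq!(config.n_kv_heads, 4);
+        assert_eq!(config.seq_len, 64);
+        Ok(())
+    }
+
+    // Four little-endian f32 values become a (2, 2) row-major tensor.
+    #[test]
+    fn read_tensor_round_trip() -> Result<()> {
+        let values = [1f32, 2., 3., 4.];
+        let mut bytes = Vec::new();
+        for v in values {
+            bytes.extend_from_slice(&v.to_le_bytes());
+        }
+        let t = read_tensor(&mut std::io::Cursor::new(bytes), (2, 2), &Device::Cpu)?;
+        assert_eq!(t.to_vec2::<f32>()?, vec![vec![1f32, 2.], vec![3., 4.]]);
+        Ok(())
+    }
+}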