diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-24 12:36:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 12:36:02 +0100 |
commit | 5a26cba7339e326eaca7a10ee99f6af948da2677 (patch) | |
tree | e7ce4f569f3d620bd73c0bbb00198031345723b2 /candle-wasm-examples/llama2-c/src/worker.rs | |
parent | 550a13a5472fd3aa3975c2453eff4bff6ac1d0bd (diff) | |
download | candle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.gz candle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.bz2 candle-5a26cba7339e326eaca7a10ee99f6af948da2677.zip |
Re-organize the wasm examples (#231)
* Move the whisper example.
* More renaming.
* Add llama2 as a new wasm example.
* Live generation.
* More of the llama wasm example.
* Formatting.
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 353 |
1 files changed, 353 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs new file mode 100644 index 00000000..9b0351d6 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -0,0 +1,353 @@ +use crate::model::{Cache, Config, Llama}; +use byteorder::{LittleEndian, ReadBytesExt}; +use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D}; +use candle_nn::VarBuilder; +use rand::{distributions::Distribution, SeedableRng}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; +use yew_agent::{HandlerId, Public, WorkerLink}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string())) +} + +// Communication to the worker happens through bincode, the model weights and configs are fetched +// on the main thread and transfered via the following structure. +#[derive(Serialize, Deserialize)] +pub struct ModelData { + pub tokenizer: Vec<u8>, + pub model: Vec<u8>, +} + +fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> { + let mut buf = [0u8; 4]; + r.read_exact(&mut buf)?; + Ok(i32::from_le_bytes(buf)) +} + +fn read_tensor<R: std::io::Read, S: Into<Shape>>( + r: &mut R, + shape: S, + dev: &Device, +) -> Result<Tensor> { + let shape = shape.into(); + let mut data_t = vec![0f32; shape.elem_count()]; + r.read_f32_into::<LittleEndian>(&mut data_t)?; + let tensor = Tensor::from_vec(data_t, shape, dev)?; + Ok(tensor) +} + +struct Tokenizer { + tokens: Vec<String>, +} + +impl Tokenizer { + fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> { + let mut tokens = Vec::with_capacity(c.vocab_size); + for _token_index in 0..c.vocab_size { + let token_len = read_i32(r)?; + let mut token = vec![0u8; token_len as usize]; + r.read_exact(&mut token)?; + tokens.push(String::from_utf8_lossy(&token).into_owned()) + } + Ok(Self { tokens }) + } +} + +struct Model { + cache: Cache, + config: Config, + llama: Llama, + tokenizer: Tokenizer, +} + +pub struct LogitsProcessor { + rng: rand::rngs::StdRng, + temperature: Option<f64>, +} + +impl LogitsProcessor { + pub fn new(seed: u64, temperature: Option<f64>) -> Self { + Self { + rng: rand::rngs::StdRng::seed_from_u64(seed), + temperature, + } + } + + pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { + let logits = logits.to_dtype(DType::F32)?; + let next_token = if let Some(temperature) = self.temperature { + let prs = (&logits / temperature)?.softmax(D::Minus1)?; + let prs: Vec<f32> = prs.to_vec1()?; + let distr = + rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?; + distr.sample(&mut self.rng) as u32 + } else { + let logits_v: Vec<f32> = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + Ok(next_token) + } +} + +impl Model { + fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> { + let dev = Device::Cpu; + let mut logits_processor = LogitsProcessor::new(299792458, None); + let mut index_pos = 0; + let mut tokens = vec![1u32]; + + for index in 0..self.config.seq_len - 10 { + let context_size = if self.cache.use_kv_cache && index > 0 { + 1 + } else { + tokens.len() + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &dev)?.unsqueeze(0)?; + let logits = self.llama.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + let token = self.tokenizer.tokens[next_token as usize].clone(); + link.respond(id, Ok(WorkerOutput::Generated(token))); + } + Ok(()) + } +} + +impl Config { + fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> { + let dim = read_i32(r)? as usize; + let hidden_dim = read_i32(r)? as usize; + let n_layers = read_i32(r)? as usize; + let n_heads = read_i32(r)? as usize; + let n_kv_heads = read_i32(r)? as usize; + let vocab_size = read_i32(r)? as usize; + let seq_len = read_i32(r)? as usize; + Ok(Self { + dim, + hidden_dim, + n_layers, + n_heads, + n_kv_heads, + vocab_size, + seq_len, + norm_eps: 1e-5, + }) + } + + pub fn head_size(&self) -> usize { + self.dim / self.n_heads + } +} + +struct TransformerWeights { + // token embedding table + token_embedding_table: Tensor, // (vocab_size, dim) + // weights for rmsnorms + rms_att_weight: Tensor, // (layer, dim) rmsnorm weights + rms_ffn_weight: Tensor, // (layer, dim) + // weights for matmuls + wq: Tensor, // (layer, dim, dim) + wk: Tensor, // (layer, dim, dim) + wv: Tensor, // (layer, dim, dim) + wo: Tensor, // (layer, dim, dim) + // weights for ffn + w1: Tensor, // (layer, hidden_dim, dim) + w2: Tensor, // (layer, dim, hidden_dim) + w3: Tensor, // (layer, hidden_dim, dim) + // final rmsnorm + rms_final_weight: Tensor, // (dim,) + // freq_cis for RoPE relatively positional embeddings + freq_cis_real: Tensor, // (seq_len, head_size/2) + freq_cis_imag: Tensor, // (seq_len, head_size/2) +} + +impl TransformerWeights { + fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> { + let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?; + let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?; + let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let rms_final_weight = read_tensor(r, c.dim, dev)?; + let head_size = c.head_size(); + let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + Ok(Self { + token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + }) + } + + fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> { + let mut ws = std::collections::HashMap::new(); + let mut insert = |name: &str, t: Tensor| { + ws.insert(name.to_string(), t); + }; + insert("rot.freq_cis_real", self.freq_cis_real.clone()); + insert("rot.freq_cis_imag", self.freq_cis_imag.clone()); + insert( + "model.embed_tokens.weight", + self.token_embedding_table.clone(), + ); + insert("lm_head.weight", self.token_embedding_table.clone()); + insert("model.norm.weight", self.rms_final_weight.clone()); + for layer in 0..cfg.n_layers { + ws.insert( + format!("model.layers.{layer}.self_attn.q_proj.weight"), + self.wq.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.k_proj.weight"), + self.wk.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.v_proj.weight"), + self.wv.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.o_proj.weight"), + self.wo.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.gate_proj.weight"), + self.w1.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.down_proj.weight"), + self.w2.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.up_proj.weight"), + self.w3.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.input_layernorm.weight"), + self.rms_att_weight.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.post_attention_layernorm.weight"), + self.rms_ffn_weight.i(layer)?, + ); + } + let vb = VarBuilder::from_tensors(ws, DType::F32, device); + Ok(vb) + } +} + +impl Model { + fn load(md: ModelData) -> Result<Self> { + let dev = Device::Cpu; + let mut model = std::io::Cursor::new(md.model); + let config = Config::from_reader(&mut model)?; + let weights = TransformerWeights::from_reader(&mut model, &config, &dev)?; + let vb = weights.var_builder(&config, &dev)?; + let cache = Cache::new(true, &config, vb.pp("rot"))?; + let llama = Llama::load(vb, &cache, &config)?; + let mut tokenizer = std::io::Cursor::new(md.tokenizer); + let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?; + Ok(Self { + cache, + config, + llama, + tokenizer, + }) + } +} + +pub struct Worker { + link: WorkerLink<Self>, + model: Option<Model>, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerInput { + ModelData(ModelData), + Run, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerOutput { + Generated(String), + GenerationDone(std::result::Result<(), String>), + WeightsLoaded, +} + +impl yew_agent::Worker for Worker { + type Input = WorkerInput; + type Message = (); + type Output = std::result::Result<WorkerOutput, String>; + type Reach = Public<Self>; + + fn create(link: WorkerLink<Self>) -> Self { + Self { link, model: None } + } + + fn update(&mut self, _msg: Self::Message) { + // no messaging + } + + fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { + let output = match msg { + WorkerInput::ModelData(md) => match Model::load(md) { + Ok(model) => { + self.model = Some(model); + Ok(WorkerOutput::WeightsLoaded) + } + Err(err) => Err(format!("model creation error {err:?}")), + }, + WorkerInput::Run => match &self.model { + None => Err("model has not been set yet".to_string()), + Some(model) => { + let result = model.run(&self.link, id).map_err(|e| e.to_string()); + Ok(WorkerOutput::GenerationDone(result)) + } + }, + }; + self.link.respond(id, output); + } + + fn name_of_resource() -> &'static str { + "worker.js" + } + + fn resource_path_is_relative() -> bool { + true + } +} |