use candle::{Device, Tensor}; use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData}; use wasm_bindgen::prelude::*; #[wasm_bindgen] pub struct Model { inner: M, logits_processor: LogitsProcessor, tokens: Vec, } impl Model { fn process(&mut self, tokens: &[u32]) -> candle::Result { let dev = Device::Cpu; let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?; let logits = self.inner.llama.forward(&input, tokens.len())?; let logits = logits.squeeze(0)?; let next_token = self.logits_processor.sample(&logits)?; self.tokens.push(next_token); let text = match self.inner.tokenizer.id_to_token(next_token) { Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"), None => "".to_string(), }; Ok(text) } } #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] pub fn new(weights: Vec, tokenizer: Vec) -> Result { let model = M::load(ModelData { tokenizer, model: weights, }); let logits_processor = LogitsProcessor::new(299792458, None); match model { Ok(inner) => Ok(Self { inner, logits_processor, tokens: vec![], }), Err(e) => Err(JsError::new(&e.to_string())), } } #[wasm_bindgen] pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result { // First reset the cache. { let mut cache = self.inner.cache.kvs.lock().unwrap(); for elem in cache.iter_mut() { *elem = None } } let temp = if temp <= 0. { None } else { Some(temp) }; self.logits_processor = LogitsProcessor::new(299792458, temp); self.tokens.clear(); let tokens = self .inner .tokenizer .encode(prompt.to_string(), true) .map_err(|m| JsError::new(&m.to_string()))? .get_ids() .to_vec(); let text = self .process(&tokens) .map_err(|m| JsError::new(&m.to_string()))?; Ok(text) } #[wasm_bindgen] pub fn next_token(&mut self) -> Result { let last_token = *self.tokens.last().unwrap(); let text = self .process(&[last_token]) .map_err(|m| JsError::new(&m.to_string()))?; Ok(text) } } fn main() {}