diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-31 08:44:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-31 07:44:32 +0100 |
commit | 8e84d8a59beeaa0ab051ac0d8febf1b01a234f75 (patch) | |
tree | c1e82d1f7765dc605e8e3bba35e7f9a70f2c2519 /candle-wasm-examples/llama2-c | |
parent | 9bd486fb96b691c25cbf7eeeb1ac886a982563a6 (diff) | |
download | candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.tar.gz candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.tar.bz2 candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.zip |
Llama2.c wasm module. (#686)
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/m.rs | 83 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/lib.rs | 4 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 10 |
3 files changed, 90 insertions, 7 deletions
```diff
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
new file mode 100644
index 00000000..ba9ed58d
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -0,0 +1,83 @@
+use candle::{Device, Tensor};
+use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Model {
+    inner: M,
+    logits_processor: LogitsProcessor,
+    tokens: Vec<u32>,
+}
+
+impl Model {
+    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
+        let dev = Device::Cpu;
+        let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
+        let logits = self.inner.llama.forward(&input, tokens.len())?;
+        let logits = logits.squeeze(0)?;
+
+        let next_token = self.logits_processor.sample(&logits)?;
+        self.tokens.push(next_token);
+        let text = match self.inner.tokenizer.id_to_token(next_token) {
+            Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"),
+            None => "".to_string(),
+        };
+        Ok(text)
+    }
+}
+
+#[wasm_bindgen]
+impl Model {
+    #[wasm_bindgen(constructor)]
+    pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
+        let model = M::load(ModelData {
+            tokenizer,
+            model: weights,
+        });
+        let logits_processor = LogitsProcessor::new(299792458, None);
+        match model {
+            Ok(inner) => Ok(Self {
+                inner,
+                logits_processor,
+                tokens: vec![],
+            }),
+            Err(e) => Err(JsError::new(&e.to_string())),
+        }
+    }
+
+    #[wasm_bindgen]
+    pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
+        // First reset the cache.
+        {
+            let mut cache = self.inner.cache.kvs.lock().unwrap();
+            for elem in cache.iter_mut() {
+                *elem = None
+            }
+        }
+        let temp = if temp <= 0. { None } else { Some(temp) };
+        self.logits_processor = LogitsProcessor::new(299792458, temp);
+        self.tokens.clear();
+        let tokens = self
+            .inner
+            .tokenizer
+            .encode(prompt.to_string(), true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+        let text = self
+            .process(&tokens)
+            .map_err(|m| JsError::new(&m.to_string()))?;
+        Ok(text)
+    }
+
+    #[wasm_bindgen]
+    pub fn next_token(&mut self) -> Result<String, JsError> {
+        let last_token = *self.tokens.last().unwrap();
+        let text = self
+            .process(&[last_token])
+            .map_err(|m| JsError::new(&m.to_string()))?;
+        Ok(text)
+    }
+}
+
+fn main() {}
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
index b6b4004f..cd7834b5 100644
--- a/candle-wasm-examples/llama2-c/src/lib.rs
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -1,5 +1,5 @@
 mod app;
-mod model;
-mod worker;
+pub mod model;
+pub mod worker;
 pub use app::App;
 pub use worker::Worker;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 0ee199af..e15aaa79 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -49,11 +49,11 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
     Ok(tensor)
 }
 
-struct Model {
-    cache: Cache,
+pub struct Model {
+    pub cache: Cache,
     config: Config,
-    llama: Llama,
-    tokenizer: Tokenizer,
+    pub llama: Llama,
+    pub tokenizer: Tokenizer,
 }
 
 pub struct LogitsProcessor {
@@ -275,7 +275,7 @@ impl TransformerWeights {
 }
 
 impl Model {
-    fn load(md: ModelData) -> Result<Self> {
+    pub fn load(md: ModelData) -> Result<Self> {
         let dev = Device::Cpu;
         let mut model = std::io::Cursor::new(md.model);
         let config = Config::from_reader(&mut model)?;
```