summary refs log tree commit diff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-08-31 08:44:32 +0200
committer GitHub <noreply@github.com> 2023-08-31 07:44:32 +0100
commit 8e84d8a59beeaa0ab051ac0d8febf1b01a234f75 (patch)
tree c1e82d1f7765dc605e8e3bba35e7f9a70f2c2519 /candle-wasm-examples/llama2-c
parent 9bd486fb96b691c25cbf7eeeb1ac886a982563a6 (diff)
download candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.tar.gz
candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.tar.bz2
candle-8e84d8a59beeaa0ab051ac0d8febf1b01a234f75.zip
Llama2.c wasm module. (#686)
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- candle-wasm-examples/llama2-c/src/bin/m.rs 83
-rw-r--r-- candle-wasm-examples/llama2-c/src/lib.rs 4
-rw-r--r-- candle-wasm-examples/llama2-c/src/worker.rs 10
3 files changed, 90 insertions, 7 deletions
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
new file mode 100644
index 00000000..ba9ed58d
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -0,0 +1,83 @@
+use candle::{Device, Tensor};
+use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
+use wasm_bindgen::prelude::*;
+
+/// JS-visible wrapper around the llama2.c model: the loaded model itself,
+/// the sampling state, and the sequence of tokens generated so far.
+#[wasm_bindgen]
+pub struct Model {
+ inner: M,
+ logits_processor: LogitsProcessor,
+ tokens: Vec<u32>,
+}
+
+impl Model {
+ /// Runs one forward pass over `tokens`, samples a single next token,
+ /// appends it to `self.tokens`, and returns its decoded text.
+ /// Returns an empty string when the sampled id has no tokenizer entry.
+ fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
+ let dev = Device::Cpu;
+ // Add a leading batch dimension: (seq) -> (1, seq).
+ let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
+ // NOTE(review): `tokens.len()` is passed as the position/offset argument of
+ // `forward`; confirm this matches the KV-cache indexing in worker.rs — a
+ // cumulative position of previously processed tokens may be expected instead.
+ let logits = self.inner.llama.forward(&input, tokens.len())?;
+ let logits = logits.squeeze(0)?;
+
+ let next_token = self.logits_processor.sample(&logits)?;
+ self.tokens.push(next_token);
+ // Map the sentencepiece word-boundary marker to a space and the
+ // <0x0A> byte token to a newline before handing the text back to JS.
+ let text = match self.inner.tokenizer.id_to_token(next_token) {
+ Some(text) => text.replace('▁', " ").replace("<0x0A>", "\n"),
+ None => "".to_string(),
+ };
+ Ok(text)
+ }
+}
+
+#[wasm_bindgen]
+impl Model {
+ /// Builds a `Model` from raw llama2.c weight bytes and tokenizer bytes.
+ /// Load failures are surfaced to JS as a `JsError`.
+ #[wasm_bindgen(constructor)]
+ pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
+ let model = M::load(ModelData {
+ tokenizer,
+ model: weights,
+ });
+ // Fixed seed so generation is reproducible across page reloads.
+ let logits_processor = LogitsProcessor::new(299792458, None);
+ match model {
+ Ok(inner) => Ok(Self {
+ inner,
+ logits_processor,
+ tokens: vec![],
+ }),
+ Err(e) => Err(JsError::new(&e.to_string())),
+ }
+ }
+
+ /// Resets the KV cache and sampling state, encodes `prompt`, processes it,
+ /// and returns the text of the first generated token.
+ /// A `temp` <= 0 disables temperature sampling (passes `None` to the
+ /// logits processor).
+ #[wasm_bindgen]
+ pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
+ // First reset the cache.
+ {
+ let mut cache = self.inner.cache.kvs.lock().unwrap();
+ for elem in cache.iter_mut() {
+ *elem = None
+ }
+ }
+ let temp = if temp <= 0. { None } else { Some(temp) };
+ self.logits_processor = LogitsProcessor::new(299792458, temp);
+ self.tokens.clear();
+ let tokens = self
+ .inner
+ .tokenizer
+ .encode(prompt.to_string(), true)
+ .map_err(|m| JsError::new(&m.to_string()))?
+ .get_ids()
+ .to_vec();
+ let text = self
+ .process(&tokens)
+ .map_err(|m| JsError::new(&m.to_string()))?;
+ Ok(text)
+ }
+
+ /// Generates and returns the next token's text, feeding back the most
+ /// recently generated token.
+ /// NOTE(review): `self.tokens.last().unwrap()` panics if this is called
+ /// before `init_with_prompt` has populated `tokens`.
+ #[wasm_bindgen]
+ pub fn next_token(&mut self) -> Result<String, JsError> {
+ let last_token = *self.tokens.last().unwrap();
+ let text = self
+ .process(&[last_token])
+ .map_err(|m| JsError::new(&m.to_string()))?;
+ Ok(text)
+ }
+}
+
+// Empty entry point: this binary exists only to export the wasm-bindgen API.
+fn main() {}
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
index b6b4004f..cd7834b5 100644
--- a/candle-wasm-examples/llama2-c/src/lib.rs
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -1,5 +1,5 @@
mod app;
-mod model;
-mod worker;
+pub mod model;
+pub mod worker;
pub use app::App;
pub use worker::Worker;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 0ee199af..e15aaa79 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -49,11 +49,11 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
Ok(tensor)
}
-struct Model {
- cache: Cache,
+pub struct Model {
+ pub cache: Cache,
config: Config,
- llama: Llama,
- tokenizer: Tokenizer,
+ pub llama: Llama,
+ pub tokenizer: Tokenizer,
}
pub struct LogitsProcessor {
@@ -275,7 +275,7 @@ impl TransformerWeights {
}
impl Model {
- fn load(md: ModelData) -> Result<Self> {
+ pub fn load(md: ModelData) -> Result<Self> {
let dev = Device::Cpu;
let mut model = std::io::Cursor::new(md.model);
let config = Config::from_reader(&mut model)?;