1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
|
use candle::{Device, Tensor};
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
use wasm_bindgen::prelude::*;
/// JavaScript-facing wrapper around the loaded model and its sampling state.
#[wasm_bindgen]
pub struct Model {
    /// Loaded model bundle; its `llama`, `tokenizer` and `cache` members are
    /// used by the methods of this type.
    inner: M,
    /// Sampler that picks the next token id from the logits.
    logits_processor: LogitsProcessor,
    /// Token ids sampled so far. Cleared by `init_with_prompt`; the prompt's
    /// own ids are not stored here, only sampled ones.
    tokens: Vec<u32>,
}
impl Model {
    /// Runs one forward pass over `tokens`, samples the next token id, appends
    /// it to `self.tokens`, and returns that token's text.
    ///
    /// The text is the tokenizer's `id_to_token` result with every `▁`
    /// replaced by a space and every `<0x0A>` replaced by a newline. An id the
    /// tokenizer cannot map yields an empty string.
    ///
    /// # Errors
    ///
    /// Propagates any `candle` error raised while building the input tensor,
    /// running the forward pass, squeezing the logits, or sampling.
    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
        // Build the input on the CPU and add a leading dimension.
        let batch = Tensor::new(tokens, &Device::Cpu)?.unsqueeze(0)?;

        // NOTE(review): the second argument is always `tokens.len()`; confirm
        // this is what `forward` expects when called with a single token.
        let logits = self
            .inner
            .llama
            .forward(&batch, tokens.len())?
            .squeeze(0)?;

        let sampled = self.logits_processor.sample(&logits)?;
        self.tokens.push(sampled);

        let piece = self
            .inner
            .tokenizer
            .id_to_token(sampled)
            .map(|raw| raw.replace('▁', " ").replace("<0x0A>", "\n"))
            .unwrap_or_default();
        Ok(piece)
    }
}
#[wasm_bindgen]
impl Model {
    /// Builds a model from serialized weights and tokenizer bytes.
    ///
    /// The sampler starts with no temperature; `init_with_prompt` replaces it
    /// before generation begins.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` carrying the loader's message when `M::load` fails.
    #[wasm_bindgen(constructor)]
    pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
        let inner = M::load(ModelData {
            tokenizer,
            model: weights,
        })
        .map_err(|e| JsError::new(&e.to_string()))?;

        Ok(Self {
            inner,
            // Same constant as in `init_with_prompt` (presumably the RNG seed).
            logits_processor: LogitsProcessor::new(299792458, None),
            tokens: vec![],
        })
    }

    /// Starts a new generation from `prompt` and returns the text of the first
    /// sampled token.
    ///
    /// Clears every cache entry, rebuilds the sampler, and empties the list of
    /// previously sampled tokens before encoding and processing the prompt.
    /// A `temp` that is `<= 0` is passed to the sampler as `None`; any other
    /// value is passed as `Some(temp)`.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` when the tokenizer fails to encode `prompt` or when
    /// the forward pass / sampling fails.
    ///
    /// # Panics
    ///
    /// Panics if locking the cache fails.
    #[wasm_bindgen]
    pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
        // Reset the cache first. The inner scope drops the guard before the
        // forward pass below runs.
        {
            let mut cache = self.inner.cache.kvs.lock().unwrap();
            for slot in cache.iter_mut() {
                *slot = None;
            }
        }

        let temp = if temp <= 0. { None } else { Some(temp) };
        self.logits_processor = LogitsProcessor::new(299792458, temp);
        self.tokens.clear();

        // `prompt` is owned, so it is moved into the tokenizer rather than cloned.
        let tokens = self
            .inner
            .tokenizer
            .encode(prompt, true)
            .map_err(|m| JsError::new(&m.to_string()))?
            .get_ids()
            .to_vec();

        self.process(&tokens)
            .map_err(|m| JsError::new(&m.to_string()))
    }

    /// Feeds the most recently sampled token back into the model and returns
    /// the text of the next sampled token.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` when no token has been sampled yet (for example when
    /// this is called before a successful `init_with_prompt`), or when the
    /// forward pass / sampling fails.
    #[wasm_bindgen]
    pub fn next_token(&mut self) -> Result<String, JsError> {
        // Report the missing-state case as an error instead of panicking, which
        // would abort the whole wasm instance.
        let last_token = self.tokens.last().copied().ok_or_else(|| {
            JsError::new("next_token called before a successful init_with_prompt")
        })?;

        self.process(&[last_token])
            .map_err(|m| JsError::new(&m.to_string()))
    }
}
// Intentionally empty entry point. Presumably this file is built as a binary
// target whose functionality is exposed only through the `wasm_bindgen`
// exports above — TODO confirm against the crate manifest.
fn main() {}
|