summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c/src/bin/m.rs
blob: ba9ed58d19ef929906542b6ef4f0803ef057549a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use candle::{Device, Tensor};
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
use wasm_bindgen::prelude::*;

/// Text-generation state exported to JavaScript through `wasm_bindgen`.
#[wasm_bindgen]
pub struct Model {
    /// The loaded model; this file uses its `llama`, `tokenizer` and `cache` members.
    inner: M,
    /// Sampler used to pick the next token from the logits.
    logits_processor: LogitsProcessor,
    /// Tokens sampled since the last `init_with_prompt` call. Prompt tokens are not
    /// stored here; only the sampled tokens are pushed.
    tokens: Vec<u32>,
}

impl Model {
    /// Runs the model on `tokens`, samples one new token, records it in
    /// `self.tokens`, and returns that token rendered as text.
    ///
    /// The returned text has the `▁` marker replaced by a space and the
    /// `<0x0A>` marker replaced by a newline. If the tokenizer has no entry for
    /// the sampled id, an empty string is returned.
    ///
    /// # Errors
    ///
    /// Propagates any `candle` error raised while building the input tensor,
    /// running the forward pass, or sampling.
    fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
        let device = Device::Cpu;
        // `unsqueeze(0)` on the input and `squeeze(0)` on the output wrap and
        // unwrap dimension 0 around the forward pass.
        let batch = Tensor::new(tokens, &device)?.unsqueeze(0)?;
        // NOTE(review): the second argument is the number of tokens passed in this
        // call; confirm that this is what `forward` expects (e.g. it may want a
        // position offset instead).
        let logits = self
            .inner
            .llama
            .forward(&batch, tokens.len())?
            .squeeze(0)?;

        let sampled = self.logits_processor.sample(&logits)?;
        self.tokens.push(sampled);

        let piece = self
            .inner
            .tokenizer
            .id_to_token(sampled)
            .map(|raw| raw.replace('▁', " ").replace("<0x0A>", "\n"))
            .unwrap_or_default();
        Ok(piece)
    }
}

#[wasm_bindgen]
impl Model {
    /// Builds a `Model` from the raw bytes of the weights file and the tokenizer file.
    ///
    /// The sampler starts without a temperature; `init_with_prompt` replaces it.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` carrying the loader's message if the model cannot be loaded.
    #[wasm_bindgen(constructor)]
    pub fn new(weights: Vec<u8>, tokenizer: Vec<u8>) -> Result<Model, JsError> {
        let inner = M::load(ModelData {
            tokenizer,
            model: weights,
        })
        .map_err(|e| JsError::new(&e.to_string()))?;
        Ok(Self {
            inner,
            // Same constant first argument as in `init_with_prompt`.
            logits_processor: LogitsProcessor::new(299792458, None),
            tokens: Vec::new(),
        })
    }

    /// Starts a new generation: resets the cache and the sampler, encodes `prompt`,
    /// runs it through the model and returns the text of the first sampled token.
    ///
    /// `temp` is forwarded to the sampler only when it is strictly positive; zero,
    /// negative and NaN values (for example an `undefined` argument coming from
    /// JavaScript) are passed as `None`.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` if the cache lock cannot be acquired, if the prompt
    /// cannot be encoded, or if the model run fails.
    #[wasm_bindgen]
    pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
        // Clear every cache slot so nothing from a previous prompt is reused.
        // The guard is scoped so the lock is released before `process` runs.
        {
            let mut cache = self
                .inner
                .cache
                .kvs
                .lock()
                .map_err(|e| JsError::new(&e.to_string()))?;
            for slot in cache.iter_mut() {
                *slot = None;
            }
        }
        // Written as `temp > 0.` rather than `!(temp <= 0.)` so that NaN also maps to `None`.
        let temp = if temp > 0. { Some(temp) } else { None };
        self.logits_processor = LogitsProcessor::new(299792458, temp);
        self.tokens.clear();
        let tokens = self
            .inner
            .tokenizer
            .encode(prompt, true)
            .map_err(|m| JsError::new(&m.to_string()))?
            .get_ids()
            .to_vec();
        self.process(&tokens)
            .map_err(|m| JsError::new(&m.to_string()))
    }

    /// Feeds the most recently sampled token back into the model and returns the
    /// text of the next sampled token.
    ///
    /// # Errors
    ///
    /// Returns a `JsError` if no token has been sampled yet (i.e. this is called
    /// before a successful `init_with_prompt`), or if the model run fails.
    #[wasm_bindgen]
    pub fn next_token(&mut self) -> Result<String, JsError> {
        // Report a wrong call order to the caller instead of panicking.
        let last_token = self.tokens.last().copied().ok_or_else(|| {
            JsError::new("next_token called before init_with_prompt: no token to continue from")
        })?;
        self.process(&[last_token])
            .map_err(|m| JsError::new(&m.to_string()))
    }
}

// Intentionally empty: this file is a binary target (it lives under `src/bin/`),
// which requires a `main`; the functionality is exposed through the
// `#[wasm_bindgen]` items in this file.
fn main() {}