diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-01 20:32:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-01 19:32:28 +0100 |
commit | 2fef14cb14f373805a72862daad3a41e5e500dd7 (patch) | |
tree | 177657e875a7240e7c23c710c6c53e2e9ef0d151 /candle-wasm-examples/llama2-c/src/bin/m.rs | |
parent | 1e5b2cc1d5144dcbb86356b99d1aec91dc416473 (diff) | |
download | candle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.gz candle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.bz2 candle-2fef14cb14f373805a72862daad3a41e5e500dd7.zip |
Add a repeat penalty to the llama2.c wasm example. (#709)
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/bin/m.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/m.rs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index ba9ed58d..ec5f7389 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -1,5 +1,6 @@ use candle::{Device, Tensor}; -use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData}; +use candle_transformers::generation::LogitsProcessor; +use candle_wasm_example_llama2::worker::{Model as M, ModelData}; use wasm_bindgen::prelude::*; #[wasm_bindgen] @@ -7,14 +8,26 @@ pub struct Model { inner: M, logits_processor: LogitsProcessor, tokens: Vec<u32>, + repeat_penalty: f32, } impl Model { fn process(&mut self, tokens: &[u32]) -> candle::Result<String> { + const REPEAT_LAST_N: usize = 64; let dev = Device::Cpu; let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?; let logits = self.inner.llama.forward(&input, tokens.len())?; let logits = logits.squeeze(0)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = self.tokens.len().saturating_sub(REPEAT_LAST_N); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; let next_token = self.logits_processor.sample(&logits)?; self.tokens.push(next_token); @@ -40,13 +53,19 @@ impl Model { inner, logits_processor, tokens: vec![], + repeat_penalty: 1., }), Err(e) => Err(JsError::new(&e.to_string())), } } #[wasm_bindgen] - pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> { + pub fn init_with_prompt( + &mut self, + prompt: String, + temp: f64, + repeat_penalty: f32, + ) -> Result<String, JsError> { // First reset the cache. { let mut cache = self.inner.cache.kvs.lock().unwrap(); @@ -56,6 +75,7 @@ impl Model { } let temp = if temp <= 0. { None } else { Some(temp) }; self.logits_processor = LogitsProcessor::new(299792458, temp); + self.repeat_penalty = repeat_penalty; self.tokens.clear(); let tokens = self .inner |