summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c/src/bin/m.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/bin/m.rs')
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/m.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index ba9ed58d..ec5f7389 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -1,5 +1,6 @@
use candle::{Device, Tensor};
-use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
+use candle_transformers::generation::LogitsProcessor;
+use candle_wasm_example_llama2::worker::{Model as M, ModelData};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
@@ -7,14 +8,26 @@ pub struct Model {
inner: M,
logits_processor: LogitsProcessor,
tokens: Vec<u32>,
+ repeat_penalty: f32,
}
impl Model {
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
+ const REPEAT_LAST_N: usize = 64;
let dev = Device::Cpu;
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
let logits = self.inner.llama.forward(&input, tokens.len())?;
let logits = logits.squeeze(0)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = self.tokens.len().saturating_sub(REPEAT_LAST_N);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
let next_token = self.logits_processor.sample(&logits)?;
self.tokens.push(next_token);
@@ -40,13 +53,19 @@ impl Model {
inner,
logits_processor,
tokens: vec![],
+ repeat_penalty: 1.,
}),
Err(e) => Err(JsError::new(&e.to_string())),
}
}
#[wasm_bindgen]
- pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
+ pub fn init_with_prompt(
+ &mut self,
+ prompt: String,
+ temp: f64,
+ repeat_penalty: f32,
+ ) -> Result<String, JsError> {
// First reset the cache.
{
let mut cache = self.inner.cache.kvs.lock().unwrap();
@@ -56,6 +75,7 @@ impl Model {
}
let temp = if temp <= 0. { None } else { Some(temp) };
self.logits_processor = LogitsProcessor::new(299792458, temp);
+ self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self
.inner