diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-01 20:32:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-01 19:32:28 +0100 |
commit | 2fef14cb14f373805a72862daad3a41e5e500dd7 (patch) | |
tree | 177657e875a7240e7c23c710c6c53e2e9ef0d151 /candle-wasm-examples/llama2-c/src/worker.rs | |
parent | 1e5b2cc1d5144dcbb86356b99d1aec91dc416473 (diff) | |
download | candle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.gz candle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.bz2 candle-2fef14cb14f373805a72862daad3a41e5e500dd7.zip |
Add a repeat penalty to the llama2.c wasm example. (#709)
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 40 |
1 files changed, 3 insertions, 37 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index e15aaa79..3d187fcc 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -1,8 +1,8 @@ use crate::model::{Cache, Config, Llama}; use byteorder::{LittleEndian, ReadBytesExt}; -use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D}; -use candle_nn::{ops::softmax, VarBuilder}; -use rand::{distributions::Distribution, SeedableRng}; +use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; @@ -56,40 +56,6 @@ pub struct Model { pub tokenizer: Tokenizer, } -pub struct LogitsProcessor { - rng: rand::rngs::StdRng, - temperature: Option<f64>, -} - -impl LogitsProcessor { - pub fn new(seed: u64, temperature: Option<f64>) -> Self { - Self { - rng: rand::rngs::StdRng::seed_from_u64(seed), - temperature, - } - } - - pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { - let logits = logits.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = self.temperature { - let prs = softmax(&(&logits / temperature)?, D::Minus1)?; - let prs: Vec<f32> = prs.to_vec1()?; - let distr = - rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?; - distr.sample(&mut self.rng) as u32 - } else { - let logits_v: Vec<f32> = logits.to_vec1()?; - logits_v - .iter() - .enumerate() - .max_by(|(_, u), (_, v)| u.total_cmp(v)) - .map(|(i, _)| i as u32) - .unwrap() - }; - Ok(next_token) - } -} - impl Model { fn run( &self, |