summary refs log tree commit diff
path: root/candle-wasm-examples/llama2-c/src/worker.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-01 20:32:28 +0200
committerGitHub <noreply@github.com>2023-09-01 19:32:28 +0100
commit2fef14cb14f373805a72862daad3a41e5e500dd7 (patch)
tree177657e875a7240e7c23c710c6c53e2e9ef0d151 /candle-wasm-examples/llama2-c/src/worker.rs
parent1e5b2cc1d5144dcbb86356b99d1aec91dc416473 (diff)
downloadcandle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.gz
candle-2fef14cb14f373805a72862daad3a41e5e500dd7.tar.bz2
candle-2fef14cb14f373805a72862daad3a41e5e500dd7.zip
Add a repeat penalty to the llama2.c wasm example. (#709)
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r-- candle-wasm-examples/llama2-c/src/worker.rs 40
1 file changed, 3 insertions, 37 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index e15aaa79..3d187fcc 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -1,8 +1,8 @@
use crate::model::{Cache, Config, Llama};
use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::{ops::softmax, VarBuilder};
-use rand::{distributions::Distribution, SeedableRng};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@@ -56,40 +56,6 @@ pub struct Model {
pub tokenizer: Tokenizer,
}
-pub struct LogitsProcessor {
- rng: rand::rngs::StdRng,
- temperature: Option<f64>,
-}
-
-impl LogitsProcessor {
- pub fn new(seed: u64, temperature: Option<f64>) -> Self {
- Self {
- rng: rand::rngs::StdRng::seed_from_u64(seed),
- temperature,
- }
- }
-
- pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
- let logits = logits.to_dtype(DType::F32)?;
- let next_token = if let Some(temperature) = self.temperature {
- let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
- let prs: Vec<f32> = prs.to_vec1()?;
- let distr =
- rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
- distr.sample(&mut self.rng) as u32
- } else {
- let logits_v: Vec<f32> = logits.to_vec1()?;
- logits_v
- .iter()
- .enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
- };
- Ok(next_token)
- }
-}
-
impl Model {
fn run(
&self,