diff options
author | Juarez Bochi <jbochi@gmail.com> | 2023-09-12 09:10:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-12 18:10:16 +0200 |
commit | 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch) | |
tree | 0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-wasm-examples/llama2-c | |
parent | 42da17694a4214a3e39e0d64afc22635ce83f557 (diff) | |
download | candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2 candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip |
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/lib-example.html | 20 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 23 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/m.rs | 10 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 16 |
4 files changed, 57 insertions(+), 12 deletions(-)
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html index b5033c54..22b12517 100644 --- a/candle-wasm-examples/llama2-c/lib-example.html +++ b/candle-wasm-examples/llama2-c/lib-example.html @@ -56,6 +56,7 @@ const weightsURL = `${MODELS_BASE_URL}/${model.url}`; const prompt = getValue("prompt"); const temperature = getValue("temperature"); + const topP = getValue("top-p"); const repeatPenalty = getValue("repeat_penalty"); const seed = getValue("seed"); const maxSeqLen = getValue("max-seq"); @@ -99,6 +100,7 @@ tokenizerURL: "tokenizer.json", prompt, temp: temperature, + top_p: topP, repeatPenalty, seed: BigInt(seed), maxSeqLen, @@ -251,7 +253,7 @@ <input type="range" id="max-seq" - name="temperature" + name="max-seq" min="1" max="256" step="1" @@ -279,6 +281,22 @@ > 0.50</output > + <label class="text-sm font-medium" for="top-p">Top-p</label> + <input + type="range" + id="top-p" + name="top-p" + min="0" + max="1" + step="0.01" + value="1.00" + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> + <output + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > + 1.00</output + > <label class="text-sm font-medium" for="repeat_penalty" >Repeat Penalty</label diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index 782026a4..ea04a810 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -46,6 +46,7 @@ pub struct App { status: String, loaded: bool, temperature: std::rc::Rc<std::cell::RefCell<f64>>, + top_p: std::rc::Rc<std::cell::RefCell<f64>>, prompt: std::rc::Rc<std::cell::RefCell<String>>, generated: String, n_tokens: usize, @@ -81,6 +82,7 @@ impl Component for App { status, n_tokens: 0, temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)), + top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)), prompt: 
std::rc::Rc::new(std::cell::RefCell::new("".to_string())), generated: String::new(), current_decode: None, @@ -122,10 +124,11 @@ impl Component for App { self.n_tokens = 0; self.generated.clear(); let temp = *self.temperature.borrow(); + let top_p = *self.top_p.borrow(); let prompt = self.prompt.borrow().clone(); - console_log!("temp: {}, prompt: {}", temp, prompt); + console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt); ctx.link() - .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt))) + .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt))) } true } @@ -177,13 +180,21 @@ impl Component for App { fn view(&self, ctx: &Context<Self>) -> Html { use yew::TargetCast; let temperature = self.temperature.clone(); - let oninput = ctx.link().callback(move |e: yew::InputEvent| { + let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| { let input: web_sys::HtmlInputElement = e.target_unchecked_into(); if let Ok(temp) = f64::from_str(&input.value()) { *temperature.borrow_mut() = temp } Msg::Refresh }); + let top_p = self.top_p.clone(); + let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| { + let input: web_sys::HtmlInputElement = e.target_unchecked_into(); + if let Ok(top_p_input) = f64::from_str(&input.value()) { + *top_p.borrow_mut() = top_p_input + } + Msg::Refresh + }); let prompt = self.prompt.clone(); let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| { let input: web_sys::HtmlInputElement = e.target_unchecked_into(); @@ -201,9 +212,13 @@ impl Component for App { </p> </div> {"temperature \u{00a0} "} - <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/> + <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/> {format!(" \u{00a0} {}", self.temperature.borrow())} <br/ > + {"top_p \u{00a0} "} + <input type="range" min="0." 
max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/> + {format!(" \u{00a0} {}", self.top_p.borrow())} + <br/ > {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/> <br/ > { diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index 6628ab7e..61de9d7f 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -47,7 +47,7 @@ impl Model { tokenizer, model: weights, }); - let logits_processor = LogitsProcessor::new(299792458, None); + let logits_processor = LogitsProcessor::new(299792458, None, None); match model { Ok(inner) => Ok(Self { inner, @@ -69,6 +69,7 @@ impl Model { &mut self, prompt: String, temp: f64, + top_p: f64, repeat_penalty: f32, seed: u64, ) -> Result<String, JsError> { @@ -80,7 +81,12 @@ impl Model { } } let temp = if temp <= 0. { None } else { Some(temp) }; - self.logits_processor = LogitsProcessor::new(seed, temp); + let top_p = if top_p <= 0. || top_p >= 1. { + None + } else { + Some(top_p) + }; + self.logits_processor = LogitsProcessor::new(seed, temp, top_p); self.repeat_penalty = repeat_penalty; self.tokens.clear(); let tokens = self diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 7e97b5da..79dd2f32 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -62,12 +62,18 @@ impl Model { link: &WorkerLink<Worker>, id: HandlerId, temp: f64, + top_p: f64, prompt: String, ) -> Result<()> { let dev = Device::Cpu; let temp = if temp <= 0. { None } else { Some(temp) }; - console_log!("{temp:?} {prompt}"); - let mut logits_processor = LogitsProcessor::new(299792458, temp); + let top_p = if top_p <= 0. 
|| top_p >= 1.0 { + None + } else { + Some(top_p) + }; + console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}"); + let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p); let mut index_pos = 0; let mut tokens = self .tokenizer @@ -268,7 +274,7 @@ pub struct Worker { #[derive(Serialize, Deserialize)] pub enum WorkerInput { ModelData(ModelData), - Run(f64, String), + Run(f64, f64, String), } #[derive(Serialize, Deserialize)] @@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker { } Err(err) => Err(format!("model creation error {err:?}")), }, - WorkerInput::Run(temp, prompt) => match &mut self.model { + WorkerInput::Run(temp, top_p, prompt) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { { @@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker { } } let result = model - .run(&self.link, id, temp, prompt) + .run(&self.link, id, temp, top_p, prompt) .map_err(|e| e.to_string()); Ok(WorkerOutput::GenerationDone(result)) } |