From 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f Mon Sep 17 00:00:00 2001
From: Juarez Bochi <jbochi@gmail.com>
Date: Tue, 12 Sep 2023 09:10:16 -0700
Subject: Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
---
 candle-wasm-examples/llama2-c/lib-example.html | 20 +++++++++++++++++++-
 candle-wasm-examples/llama2-c/src/app.rs       | 23 +++++++++++++++++++----
 candle-wasm-examples/llama2-c/src/bin/m.rs     | 10 ++++++++--
 candle-wasm-examples/llama2-c/src/worker.rs    | 16 +++++++++++-----
 4 files changed, 57 insertions(+), 12 deletions(-)

(limited to 'candle-wasm-examples/llama2-c')

diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
index b5033c54..22b12517 100644
--- a/candle-wasm-examples/llama2-c/lib-example.html
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -56,6 +56,7 @@
         const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
         const prompt = getValue("prompt");
         const temperature = getValue("temperature");
+        const topP = getValue("top-p");
         const repeatPenalty = getValue("repeat_penalty");
         const seed = getValue("seed");
         const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
             tokenizerURL: "tokenizer.json",
             prompt,
             temp: temperature,
+            top_p: topP,
             repeatPenalty,
             seed: BigInt(seed),
             maxSeqLen,
@@ -251,7 +253,7 @@
         <input
           type="range"
           id="max-seq"
-          name="temperature"
+          name="max-seq"
           min="1"
           max="256"
           step="1"
@@ -279,6 +281,22 @@
         >
           0.50</output
         >
+        <label class="text-sm font-medium" for="top-p">Top-p</label>
+        <input
+          type="range"
+          id="top-p"
+          name="top-p"
+          min="0"
+          max="1"
+          step="0.01"
+          value="1.00"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+        />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+        >
+          1.00</output
+        >
 
         <label class="text-sm font-medium" for="repeat_penalty"
           >Repeat Penalty</label
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index 782026a4..ea04a810 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -46,6 +46,7 @@ pub struct App {
     status: String,
     loaded: bool,
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    top_p: std::rc::Rc<std::cell::RefCell<f64>>,
     prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
             prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
             generated: String::new(),
             current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
                     self.n_tokens = 0;
                     self.generated.clear();
                     let temp = *self.temperature.borrow();
+                    let top_p = *self.top_p.borrow();
                     let prompt = self.prompt.borrow().clone();
-                    console_log!("temp: {}, prompt: {}", temp, prompt);
+                    console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
                     ctx.link()
-                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+                        .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
                 }
                 true
             }
@@ -177,13 +180,21 @@ impl Component for App {
     fn view(&self, ctx: &Context<Self>) -> Html {
         use yew::TargetCast;
         let temperature = self.temperature.clone();
-        let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+        let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
             let input: web_sys::HtmlInputElement = e.target_unchecked_into();
             if let Ok(temp) = f64::from_str(&input.value()) {
                 *temperature.borrow_mut() = temp
             }
             Msg::Refresh
         });
+        let top_p = self.top_p.clone();
+        let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+            if let Ok(top_p_input) = f64::from_str(&input.value()) {
+                *top_p.borrow_mut() = top_p_input
+            }
+            Msg::Refresh
+        });
         let prompt = self.prompt.clone();
         let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
             let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
                 </p>
                 </div>
                 {"temperature  \u{00a0} "}
-                <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
+                <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
                 {format!(" \u{00a0} {}", self.temperature.borrow())}
                 <br/ >
+                {"top_p  \u{00a0} "}
+                <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
+                {format!(" \u{00a0} {}", self.top_p.borrow())}
+                <br/ >
                 {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
                 <br/ >
                 {
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index 6628ab7e..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
             tokenizer,
             model: weights,
         });
-        let logits_processor = LogitsProcessor::new(299792458, None);
+        let logits_processor = LogitsProcessor::new(299792458, None, None);
         match model {
             Ok(inner) => Ok(Self {
                 inner,
@@ -69,6 +69,7 @@ impl Model {
         &mut self,
         prompt: String,
         temp: f64,
+        top_p: f64,
         repeat_penalty: f32,
         seed: u64,
     ) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(seed, temp);
+        let top_p = if top_p <= 0. || top_p >= 1. {
+            None
+        } else {
+            Some(top_p)
+        };
+        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 7e97b5da..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -62,12 +62,18 @@ impl Model {
         link: &WorkerLink<Worker>,
         id: HandlerId,
         temp: f64,
+        top_p: f64,
         prompt: String,
     ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?} {prompt}");
-        let mut logits_processor = LogitsProcessor::new(299792458, temp);
+        let top_p = if top_p <= 0. || top_p >= 1.0 {
+            None
+        } else {
+            Some(top_p)
+        };
+        console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+        let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
         let mut index_pos = 0;
         let mut tokens = self
             .tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64, String),
+    Run(f64, f64, String),
 }
 
 #[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(temp, prompt) => match &mut self.model {
+            WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
                     {
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
                         }
                     }
                     let result = model
-                        .run(&self.link, id, temp, prompt)
+                        .run(&self.link, id, temp, top_p, prompt)
                         .map_err(|e| e.to_string());
                     Ok(WorkerOutput::GenerationDone(result))
                 }
-- 
cgit v1.2.3