summary refs log tree commit diff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
author Juarez Bochi <jbochi@gmail.com> 2023-09-12 09:10:16 -0700
committer GitHub <noreply@github.com> 2023-09-12 18:10:16 +0200
commit 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree 0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-wasm-examples/llama2-c
parent 42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
download candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- candle-wasm-examples/llama2-c/lib-example.html | 20
-rw-r--r-- candle-wasm-examples/llama2-c/src/app.rs | 23
-rw-r--r-- candle-wasm-examples/llama2-c/src/bin/m.rs | 10
-rw-r--r-- candle-wasm-examples/llama2-c/src/worker.rs | 16
4 files changed, 57 insertions, 12 deletions
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
index b5033c54..22b12517 100644
--- a/candle-wasm-examples/llama2-c/lib-example.html
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -56,6 +56,7 @@
const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
const prompt = getValue("prompt");
const temperature = getValue("temperature");
+ const topP = getValue("top-p");
const repeatPenalty = getValue("repeat_penalty");
const seed = getValue("seed");
const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
tokenizerURL: "tokenizer.json",
prompt,
temp: temperature,
+ top_p: topP,
repeatPenalty,
seed: BigInt(seed),
maxSeqLen,
@@ -251,7 +253,7 @@
<input
type="range"
id="max-seq"
- name="temperature"
+ name="max-seq"
min="1"
max="256"
step="1"
@@ -279,6 +281,22 @@
>
0.50</output
>
+ <label class="text-sm font-medium" for="top-p">Top-p</label>
+ <input
+ type="range"
+ id="top-p"
+ name="top-p"
+ min="0"
+ max="1"
+ step="0.01"
+ value="1.00"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 1.00</output
+ >
<label class="text-sm font-medium" for="repeat_penalty"
>Repeat Penalty</label
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index 782026a4..ea04a810 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -46,6 +46,7 @@ pub struct App {
status: String,
loaded: bool,
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+ top_p: std::rc::Rc<std::cell::RefCell<f64>>,
prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String,
n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
status,
n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+ top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(),
current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
self.n_tokens = 0;
self.generated.clear();
let temp = *self.temperature.borrow();
+ let top_p = *self.top_p.borrow();
let prompt = self.prompt.borrow().clone();
- console_log!("temp: {}, prompt: {}", temp, prompt);
+ console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link()
- .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+ .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
}
true
}
@@ -177,13 +180,21 @@ impl Component for App {
fn view(&self, ctx: &Context<Self>) -> Html {
use yew::TargetCast;
let temperature = self.temperature.clone();
- let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+ let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(temp) = f64::from_str(&input.value()) {
*temperature.borrow_mut() = temp
}
Msg::Refresh
});
+ let top_p = self.top_p.clone();
+ let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+ let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+ if let Ok(top_p_input) = f64::from_str(&input.value()) {
+ *top_p.borrow_mut() = top_p_input
+ }
+ Msg::Refresh
+ });
let prompt = self.prompt.clone();
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
</p>
</div>
{"temperature \u{00a0} "}
- <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
+ <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
{format!(" \u{00a0} {}", self.temperature.borrow())}
<br/ >
+ {"top_p \u{00a0} "}
+ <input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
+ {format!(" \u{00a0} {}", self.top_p.borrow())}
+ <br/ >
{"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
<br/ >
{
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index 6628ab7e..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
tokenizer,
model: weights,
});
- let logits_processor = LogitsProcessor::new(299792458, None);
+ let logits_processor = LogitsProcessor::new(299792458, None, None);
match model {
Ok(inner) => Ok(Self {
inner,
@@ -69,6 +69,7 @@ impl Model {
&mut self,
prompt: String,
temp: f64,
+ top_p: f64,
repeat_penalty: f32,
seed: u64,
) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
- self.logits_processor = LogitsProcessor::new(seed, temp);
+ let top_p = if top_p <= 0. || top_p >= 1. {
+ None
+ } else {
+ Some(top_p)
+ };
+ self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 7e97b5da..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -62,12 +62,18 @@ impl Model {
link: &WorkerLink<Worker>,
id: HandlerId,
temp: f64,
+ top_p: f64,
prompt: String,
) -> Result<()> {
let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) };
- console_log!("{temp:?} {prompt}");
- let mut logits_processor = LogitsProcessor::new(299792458, temp);
+ let top_p = if top_p <= 0. || top_p >= 1.0 {
+ None
+ } else {
+ Some(top_p)
+ };
+ console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+ let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
let mut index_pos = 0;
let mut tokens = self
.tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run(f64, String),
+ Run(f64, f64, String),
}
#[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run(temp, prompt) => match &mut self.model {
+ WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
{
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
}
}
let result = model
- .run(&self.link, id, temp, prompt)
+ .run(&self.link, id, temp, top_p, prompt)
.map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}