summary | refs | log | tree | commit | diff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
author Radamés Ajna <radamajna@gmail.com> 2023-09-26 22:01:59 -0700
committer GitHub <noreply@github.com> 2023-09-27 06:01:59 +0100
commit 9571b200c9a7c835f66a5444e62f8100e99a4102 (patch)
tree fb1e9077e3190264aac388b4aee9cc23e76cfbac /candle-wasm-examples/llama2-c
parent ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99 (diff)
download candle-9571b200c9a7c835f66a5444e62f8100e99a4102.tar.gz
candle-9571b200c9a7c835f66a5444e62f8100e99a4102.tar.bz2
candle-9571b200c9a7c835f66a5444e62f8100e99a4102.zip
fix firstToken, minor ui changes (#971)
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- candle-wasm-examples/llama2-c/lib-example.html 190
-rw-r--r-- candle-wasm-examples/llama2-c/llama2cWorker.js 11
2 files changed, 101 insertions, 100 deletions
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
index d832aa13..9b78ebde 100644
--- a/candle-wasm-examples/llama2-c/lib-example.html
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -117,6 +117,10 @@
llamaWorker.removeEventListener("message", handleMessage);
reject(new Error(error));
}
+ if (status === "aborted") {
+ llamaWorker.removeEventListener("message", handleMessage);
+ resolve(event.data);
+ }
if (status === "complete") {
llamaWorker.removeEventListener("message", handleMessage);
resolve(event.data);
@@ -212,8 +216,7 @@
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
- class="border-2 border-gray-500 rounded-md font-light"
- >
+ class="border-2 border-gray-500 rounded-md font-light">
<option value="stories15M" selected>stories 15M (60.8 MB)</option>
<option value="stories42M">stories 42M (167 MB)</option>
<option value="stories110M">stories 110M (438 MB)</option>
@@ -221,133 +224,124 @@
</div>
<form
id="form"
- class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
- >
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
<input type="submit" hidden />
<input
type="text"
id="prompt"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
placeholder="Add your prompt here..."
- value="Once upon a time"
- />
+ value="Once upon a time" />
<button id="clear-btn">
<svg
fill="none"
xmlns="http://www.w3.org/2000/svg"
width="40"
- viewBox="0 0 70 40"
- >
+ viewBox="0 0 70 40">
<path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
<path
d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
opacity=".5"
stroke="#1F2937"
- stroke-width="2"
- />
+ stroke-width="2" />
</svg>
</button>
<button
id="run"
- class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
- >
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
Run
</button>
</form>
- <div class="grid grid-cols-3 max-w-md items-center gap-3">
- <label class="text-sm font-medium" for="max-seq">Maximum length </label>
- <input
- type="range"
- id="max-seq"
- name="max-seq"
- min="1"
- max="256"
- step="1"
- value="200"
- oninput="this.nextElementSibling.value = Number(this.value)"
- />
- <output
- class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
- >
- 200</output
- >
- <label class="text-sm font-medium" for="temperature">Temperature</label>
- <input
- type="range"
- id="temperature"
- name="temperature"
- min="0"
- max="2"
- step="0.01"
- value="0.50"
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
- />
- <output
- class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
- >
- 0.50</output
- >
- <label class="text-sm font-medium" for="top-p">Top-p</label>
- <input
- type="range"
- id="top-p"
- name="top-p"
- min="0"
- max="1"
- step="0.01"
- value="1.00"
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
- />
- <output
- class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
- >
- 1.00</output
- >
+ <details>
+ <summary class="font-medium cursor-pointer">Advanced Options</summary>
+ <div class="grid grid-cols-3 max-w-md items-center gap-3 py-3">
+ <label class="text-sm font-medium" for="max-seq"
+ >Maximum length
+ </label>
+ <input
+ type="range"
+ id="max-seq"
+ name="max-seq"
+ min="1"
+ max="256"
+ step="1"
+ value="200"
+ oninput="this.nextElementSibling.value = Number(this.value)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+ 200</output
+ >
+ <label class="text-sm font-medium" for="temperature"
+ >Temperature</label
+ >
+ <input
+ type="range"
+ id="temperature"
+ name="temperature"
+ min="0"
+ max="2"
+ step="0.01"
+ value="0.40"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+ 0.40</output
+ >
+ <label class="text-sm font-medium" for="top-p">Top-p</label>
+ <input
+ type="range"
+ id="top-p"
+ name="top-p"
+ min="0"
+ max="1"
+ step="0.01"
+ value="1.00"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+ 1.00</output
+ >
- <label class="text-sm font-medium" for="repeat_penalty"
- >Repeat Penalty</label
- >
+ <label class="text-sm font-medium" for="repeat_penalty"
+ >Repeat Penalty</label
+ >
- <input
- type="range"
- id="repeat_penalty"
- name="repeat_penalty"
- min="1"
- max="2"
- step="0.01"
- value="1.10"
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
- />
- <output
- class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
- >1.10</output
- >
- <label class="text-sm font-medium" for="seed">Seed</label>
- <input
- type="number"
- id="seed"
- name="seed"
- value="299792458"
- class="font-light border border-gray-700 text-right rounded-md p-2"
- />
- <button
- id="run"
- onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
- class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"
- >
- Rand
- </button>
- </div>
+ <input
+ type="range"
+ id="repeat_penalty"
+ name="repeat_penalty"
+ min="1"
+ max="2"
+ step="0.01"
+ value="1.10"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >1.10</output
+ >
+ <label class="text-sm font-medium" for="seed">Seed</label>
+ <input
+ type="number"
+ id="seed"
+ name="seed"
+ value="299792458"
+ class="font-light border border-gray-700 text-right rounded-md p-2" />
+ <button
+ id="run"
+ onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
+ Rand
+ </button>
+ </div>
+ </details>
<div>
<h3 class="font-medium">Generation:</h3>
<div
- class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
- >
+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
<div
id="output-counter"
hidden
- class="ml-auto font-semibold grid-rows-1 text-sm"
- ></div>
+ class="ml-auto font-semibold grid-rows-1 text-sm"></div>
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No output yet</span
diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js
index abaf3401..a46b5bc8 100644
--- a/candle-wasm-examples/llama2-c/llama2cWorker.js
+++ b/candle-wasm-examples/llama2-c/llama2cWorker.js
@@ -50,6 +50,7 @@ async function generate(data) {
tokenizerURL,
prompt,
temp,
+ top_p,
repeatPenalty,
seed,
maxSeqLen,
@@ -59,11 +60,17 @@ async function generate(data) {
const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
self.postMessage({ status: "loading", message: "Initializing model" });
- model.init_with_prompt(prompt, temp, repeatPenalty, seed);
+ const firstToken = model.init_with_prompt(
+ prompt,
+ temp,
+ top_p,
+ repeatPenalty,
+ seed
+ );
const seq_len = model.get_seq_len();
- let sentence = "";
+ let sentence = firstToken;
let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
let startTime = performance.now();
let tokensCount = 0;