1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
import init, { Model } from "./build/m.js";
/**
 * Download a resource and return its body as raw bytes.
 *
 * The request is issued with `cache: "force-cache"`, so a cached copy is
 * preferred over a fresh network request when one is available.
 *
 * @param {string} url - Location of the resource to download.
 * @returns {Promise<Uint8Array>} The response body as a byte array.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url) {
  const res = await fetch(url, {
    cache: "force-cache",
  });
  // fetch() only rejects on network failure; without this check an HTTP
  // error page (404, 500, ...) would be returned to the caller as if it were
  // the requested binary payload.
  if (!res.ok) {
    throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
  }
  const data = await res.arrayBuffer();
  return new Uint8Array(data);
}
/**
 * Per-model cache of loaded `Model` objects, keyed by model ID.
 */
class Llama2C {
  // Maps modelID -> Model that has already been constructed.
  static instance = {};

  /**
   * Return the model registered under `modelID`, building it on first use.
   *
   * On a cache miss this runs `init()`, posts a "loading" status message,
   * downloads the weights and tokenizer in parallel, and constructs a
   * `Model` from the two byte arrays.
   *
   * @param {string} weightsURL - URL of the model weights.
   * @param {string} modelID - Cache key identifying the model.
   * @param {string} tokenizerURL - URL of the tokenizer data.
   * @returns {Promise<Model>} The cached or freshly built model.
   */
  static async getInstance(weightsURL, modelID, tokenizerURL) {
    const cached = this.instance[modelID];
    if (cached) {
      // Each modelID is loaded only once.
      return cached;
    }

    await init();
    self.postMessage({ status: "loading", message: "Loading Model" });

    const weightsRequest = fetchArrayBuffer(weightsURL);
    const tokenizerRequest = fetchArrayBuffer(tokenizerURL);
    const [weightBytes, tokenizerBytes] = await Promise.all([
      weightsRequest,
      tokenizerRequest,
    ]);

    this.instance[modelID] = new Model(weightBytes, tokenizerBytes);
    return this.instance[modelID];
  }
}
// AbortController for the most recently started generation; null until the
// first "start" command has been received.
let controller = null;

// Worker command dispatcher: "start" begins a generation run with the message
// payload, "abort" signals the current run to stop.
self.addEventListener("message", (event) => {
  if (event.data.command === "start") {
    controller = new AbortController();
    // generate() reports its own failures via postMessage, so the returned
    // promise is intentionally not awaited here.
    generate(event.data);
  } else if (event.data.command === "abort") {
    // An "abort" can arrive before any "start"; there is nothing to cancel
    // then, and calling abort() on null would throw inside the handler.
    if (controller) {
      controller.abort();
    }
  }
});
/**
 * Run one text-generation session and stream progress to the main thread.
 *
 * Messages posted via `self.postMessage`:
 *  - `{ status: "loading", message }` while the model is being prepared,
 *  - `{ status: "generating", token, sentence, prompt }` after each token,
 *  - `{ status: "aborted", output }` if the shared controller was aborted,
 *  - `{ status: "complete", output }` once the token budget is exhausted,
 *  - `{ error }` if anything throws along the way.
 *
 * @param {object} data - Payload of the "start" message.
 * @param {string} data.weightsURL - URL of the model weights.
 * @param {string} data.modelID - Cache key for the model.
 * @param {string} data.tokenizerURL - URL of the tokenizer data.
 * @param {string} data.prompt - Text the generation starts from.
 * @param {number} data.temp - Forwarded to `model.init_with_prompt`.
 * @param {number} data.repeatPenalty - Forwarded to `model.init_with_prompt`.
 * @param {number} data.seed - Forwarded to `model.init_with_prompt`.
 * @param {number} [data.maxSeqLen] - Token budget; when falsy the budget is
 *   derived from the model's sequence length and the prompt length.
 * @returns {Promise<void>} Settles when generation ends, aborts, or fails.
 */
async function generate(data) {
  const {
    weightsURL,
    modelID,
    tokenizerURL,
    prompt,
    temp,
    repeatPenalty,
    seed,
    maxSeqLen,
  } = data;
  try {
    self.postMessage({ status: "loading", message: "Starting llama2.c" });
    const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
    self.postMessage({ status: "loading", message: "Initializing model" });
    model.init_with_prompt(prompt, temp, repeatPenalty, seed);
    const seqLen = model.get_seq_len();
    let sentence = "";
    // NOTE(review): the fallback subtracts the prompt's character count, not
    // its token count — presumably a conservative estimate; confirm.
    const maxTokens = maxSeqLen ? maxSeqLen : seqLen - prompt.length - 1;
    // A counted loop terminates for negative or fractional budgets as well,
    // where a `while (n--)` countdown would never reach zero.
    for (let produced = 0; produced < maxTokens; produced++) {
      if (controller && controller.signal.aborted) {
        self.postMessage({
          status: "aborted",
          message: "Aborted",
          output: prompt + sentence,
        });
        return;
      }
      // Awaited directly inside the try block so that a failing
      // next_token() reaches the catch below and is reported.
      const token = await model.next_token();
      sentence += token;
      self.postMessage({
        status: "generating",
        message: "Generating token",
        token: token,
        sentence: sentence,
        prompt: prompt,
      });
      // Yield to the event loop so a pending "abort" message can be handled
      // before the next token is produced.
      await new Promise((resolve) => setTimeout(resolve, 0));
    }
    self.postMessage({
      status: "complete",
      message: "complete",
      output: prompt + sentence,
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
}
|