summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/t5/utils.js
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/t5/utils.js')
-rw-r--r--candle-wasm-examples/t5/utils.js168
1 file changed, 168 insertions, 0 deletions
diff --git a/candle-wasm-examples/t5/utils.js b/candle-wasm-examples/t5/utils.js
new file mode 100644
index 00000000..e45e7d1b
--- /dev/null
+++ b/candle-wasm-examples/t5/utils.js
@@ -0,0 +1,168 @@
/**
 * Ask the worker to compute embeddings for a batch of sentences.
 *
 * @param {Worker} worker - web worker running the T5 encoder.
 * @param {string} weightsURL - URL of the model weights.
 * @param {string} tokenizerURL - URL of the tokenizer JSON.
 * @param {string} configURL - URL of the model config JSON.
 * @param {string} modelID - key identifying the model (see MODELS).
 * @param {string[]} sentences - sentences to embed.
 * @param {(data: object) => void} [updateStatus] - optional progress callback,
 *   invoked with every message the worker posts (including the final one).
 * @param {boolean} [normalize_embeddings=true] - L2-normalize the output.
 * @returns {Promise<object>} the worker's final "complete" payload.
 * @throws {Error} rejects with the worker-reported error message.
 */
export async function extractEmbeddings(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  sentences,
  updateStatus,
  normalize_embeddings = true
) {
  return new Promise((resolve, reject) => {
    function messageHandler(event) {
      // else-if chain so a single message can settle the promise only once.
      if ("error" in event.data) {
        worker.removeEventListener("message", messageHandler);
        reject(new Error(event.data.error));
      } else if (event.data.status === "complete") {
        worker.removeEventListener("message", messageHandler);
        resolve(event.data);
      }
      // Report every message, including the terminal one, to the UI.
      updateStatus?.(event.data);
    }
    // Register the listener BEFORE posting the request so no reply from the
    // worker can ever be missed.
    worker.addEventListener("message", messageHandler);
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      sentences,
      normalize_embeddings,
    });
  });
}
+
/**
 * Ask the worker to run T5 text generation for a prompt.
 *
 * @param {Worker} worker - web worker running the T5 model.
 * @param {string} weightsURL - URL of the model weights.
 * @param {string} tokenizerURL - URL of the tokenizer JSON.
 * @param {string} configURL - URL of the model config JSON.
 * @param {string} modelID - key identifying the model (see MODELS).
 * @param {string} prompt - full prompt (task prefix + input text).
 * @param {object} params - generation parameters forwarded to the worker.
 * @param {(data: object) => void} [updateStatus] - optional progress callback,
 *   invoked with every message the worker posts (including the final one).
 * @returns {Promise<object>} the worker's final "complete" payload.
 * @throws {Error} rejects with the worker-reported error message.
 */
export async function generateText(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  prompt,
  params,
  updateStatus
) {
  return new Promise((resolve, reject) => {
    function messageHandler(event) {
      // else-if chain so a single message can settle the promise only once.
      if ("error" in event.data) {
        worker.removeEventListener("message", messageHandler);
        reject(new Error(event.data.error));
      } else if (event.data.status === "complete") {
        worker.removeEventListener("message", messageHandler);
        resolve(event.data);
      }
      // Report every message, including the terminal one, to the UI.
      updateStatus?.(event.data);
    }
    // Register the listener BEFORE posting the request so no reply from the
    // worker can ever be missed.
    worker.addEventListener("message", messageHandler);
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      prompt,
      params,
    });
  });
}
// Every model currently supports the same task set. Build a FRESH copy per
// model (factory, not a shared constant) so the task objects stay independent
// — mutating one model's tasks must not leak into another's, exactly as with
// the previous copy-pasted literals.
const defaultT5Tasks = () => ({
  translation_en_to_de: {
    prefix: "translate English to German: ",
    max_length: 300,
  },
  translation_en_to_fr: {
    prefix: "translate English to French: ",
    max_length: 300,
  },
  translation_en_to_ro: {
    prefix: "translate English to Romanian: ",
    max_length: 300,
  },
  summarization: { prefix: "summarize: ", max_length: 200 },
});

/**
 * Catalog of downloadable T5 models: download size, base URL, and the file
 * names for weights/tokenizer/config, plus the per-task prompt prefixes and
 * generation length limits.
 */
export const MODELS = {
  t5_small_quantized: {
    size: "102 MB",
    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
    model: "model.gguf",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: defaultT5Tasks(),
  },
  t5_small: {
    size: "242 MB",
    base_url: "https://huggingface.co/t5-small/resolve/main/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: defaultT5Tasks(),
  },
  flan_t5_small: {
    size: "308 MB",
    base_url:
      "https://huggingface.co/google/flan-t5-small/resolve/refs%2Fpr%2F14/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: defaultT5Tasks(),
  },
  flan_t5_base_quantized: {
    size: "360 MB",
    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
    model: "model-flan-t5-base.gguf",
    tokenizer: "tokenizer.json",
    config: "config-flan-t5-base.json",
    tasks: defaultT5Tasks(),
  },
};
/**
 * Resolve the download URLs and generation limit for a model/task pair.
 *
 * @param {string} id - key into MODELS (e.g. "t5_small").
 * @param {string} taskID - key into the model's tasks (e.g. "summarization").
 * @returns {{modelURL: string, configURL: string, tokenizerURL: string, maxLength: number}}
 * @throws {Error} if the model ID or task ID is unknown, instead of the
 *   opaque TypeError a bad lookup used to produce.
 */
export function getModelInfo(id, taskID) {
  const model = MODELS[id];
  if (!model) {
    throw new Error(`Unknown model ID: ${id}`);
  }
  const task = model.tasks[taskID];
  if (!task) {
    throw new Error(`Unknown task "${taskID}" for model "${id}"`);
  }
  return {
    modelURL: model.base_url + model.model,
    configURL: model.base_url + model.config,
    tokenizerURL: model.base_url + model.tokenizer,
    maxLength: task.max_length,
  };
}