summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJuarez Bochi <juarez.bochi@grammarly.com>2023-10-04 15:57:33 -0400
committerGitHub <noreply@github.com>2023-10-04 20:57:33 +0100
commitb86ac0c5076534e2a7c067e87d1125d4da21cd22 (patch)
tree64d85729cad805643c08365ca9155e441d3c7619
parent27e70a50939b647a7c2e80428647f5668e592607 (diff)
downloadcandle-b86ac0c5076534e2a7c067e87d1125d4da21cd22.tar.gz
candle-b86ac0c5076534e2a7c067e87d1125d4da21cd22.tar.bz2
candle-b86ac0c5076534e2a7c067e87d1125d4da21cd22.zip
Quant t5: Add coedit model to wasm demo and readme (#1031)
-rw-r--r--candle-examples/examples/quantized-t5/README.md27
-rw-r--r--candle-wasm-examples/t5/index.html12
-rw-r--r--candle-wasm-examples/t5/utils.js36
3 files changed, 70 insertions, 5 deletions
diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md
index 1f6b99eb..4a1ee5bf 100644
--- a/candle-examples/examples/quantized-t5/README.md
+++ b/candle-examples/examples/quantized-t5/README.md
@@ -13,5 +13,31 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
-cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
+```
+
+To use a different model, specify the `model-id`. For example, you can use
+quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
+
+```bash
+$ cargo run --example quantized-t5 --release -- \
+ --model-id "jbochi/candle-coedit-quantized" \
+ --prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
+ --temperature 0
+...
+ Although their flight is weak, they run quickly through the tree canopy.
+```
+
+By default, it will look for `model.gguf` and `config.json`, but you can specify
+custom local or remote `weight-file` and `config-file`s:
+
+```bash
+cargo run --example quantized-t5 --release -- \
+ --model-id "jbochi/candle-coedit-quantized" \
+ --weight-file "model-xl.gguf" \
+ --config-file "config-xl.json" \
+ --prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
+ --temperature 0
+...
+ Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```
diff --git a/candle-wasm-examples/t5/index.html b/candle-wasm-examples/t5/index.html
index 227b723a..2c9a6f35 100644
--- a/candle-wasm-examples/t5/index.html
+++ b/candle-wasm-examples/t5/index.html
@@ -166,13 +166,19 @@
target="_blank"
class="link"
>flan-t5-small</a
- >
- and several t5
+ >,
+ several
<a
href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
target="_blank"
class="link">
- t5 quantized gguf</a
+ t5 quantized gguf models</a
+ >, and also a quantized
+ <a
+ href="https://huggingface.co/jbochi/candle-coedit-quantized/tree/main"
+ target="_blank"
+ class="link">
+ CoEdIT model for text rewrite</a
>.
</p>
</div>
diff --git a/candle-wasm-examples/t5/utils.js b/candle-wasm-examples/t5/utils.js
index 851d1b76..20b0a792 100644
--- a/candle-wasm-examples/t5/utils.js
+++ b/candle-wasm-examples/t5/utils.js
@@ -65,6 +65,7 @@ export async function generateText(
worker.addEventListener("message", messageHandler);
});
}
+
export const MODELS = {
t5_small_quantized: {
size: "64.4 MB",
@@ -133,7 +134,6 @@ export const MODELS = {
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
-
flan_t5_base_quantized: {
size: "263 MB",
base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
@@ -156,7 +156,41 @@ export const MODELS = {
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
+ coedit_large_quantized: {
+ size: "643 MB",
+ base_url: "https://huggingface.co/jbochi/candle-coedit-quantized/resolve/main/",
+ model: "model.gguf",
+ tokenizer: "tokenizer.json",
+ config: "config.json",
+ tasks: {
+ fluency: {
+ prefix: "Fix the grammar: ",
+ max_length: 300,
+ },
+ coherence: {
+ prefix: "Rewrite to make this easier to understand: ",
+ max_length: 300,
+ },
+ simplification: {
+ prefix: "Make this text simpler: ",
+ max_length: 300,
+ },
+ paraphrase: {
+ prefix: "Paraphrase this: ",
+ max_length: 300,
+ },
+ formalization: {
+ prefix: "Write this more formally: ",
+ max_length: 300,
+ },
+ neutralize: {
+ prefix: "Write in a more neutral way: ",
+ max_length: 300,
+ },
+ },
+ },
};
+
export function getModelInfo(id, taskID) {
const model = MODELS[id];
return {