summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/extract_weights.py
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-04 12:03:28 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-04 12:03:28 +0100
commitaea090401ded5789e95f1f8efb7404a66b508356 (patch)
treeb327653cf20da250797e45662c44ba11cbd1ea39 /candle-examples/examples/whisper/extract_weights.py
parent950b4af49e56b640b87eb273e839b2fd466e1424 (diff)
downloadcandle-aea090401ded5789e95f1f8efb7404a66b508356.tar.gz
candle-aea090401ded5789e95f1f8efb7404a66b508356.tar.bz2
candle-aea090401ded5789e95f1f8efb7404a66b508356.zip
Run the text decoding bit.
Diffstat (limited to 'candle-examples/examples/whisper/extract_weights.py')
-rw-r--r--candle-examples/examples/whisper/extract_weights.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/whisper/extract_weights.py b/candle-examples/examples/whisper/extract_weights.py
index d6ccffc6..65602703 100644
--- a/candle-examples/examples/whisper/extract_weights.py
+++ b/candle-examples/examples/whisper/extract_weights.py
@@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
weights = {}
for k, v in data["model_state_dict"].items():
weights[k] = v.contiguous()
- print(k, v.shape)
+ print(k, v.shape, v.dtype)
save_file(weights, "tiny.en.safetensors")
print(data["dims"])