Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--  candle-pyo3/quant-llama.py  7
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 0f7a51c6..020d525d 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,6 +1,7 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
import candle
+from candle.utils import load_ggml,load_gguf
MAX_SEQ_LEN = 4096
@@ -154,7 +155,7 @@ def main():
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
- all_tensors, metadata = candle.load_gguf(sys.argv[1])
+ all_tensors, metadata = load_gguf(sys.argv[1])
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -168,13 +169,13 @@ def main():
'n_head_kv': metadata['llama.attention.head_count_kv'],
'n_layer': metadata['llama.block_count'],
'n_rot': metadata['llama.rope.dimension_count'],
- 'rope_freq': metadata['llama.rope.freq_base'],
+ 'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
'ftype': metadata['general.file_type'],
}
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
else:
- all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")