Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r-- | candle-pyo3/quant-llama.py | 7
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 0f7a51c6..020d525d 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,6 +1,7 @@
 # This example shows how the candle Python api can be used to replicate llama.cpp.
 import sys
 import candle
+from candle.utils import load_ggml,load_gguf
 
 MAX_SEQ_LEN = 4096
 
@@ -154,7 +155,7 @@ def main():
     filename = sys.argv[1]
     print(f"reading model file {filename}")
     if filename.endswith("gguf"):
-        all_tensors, metadata = candle.load_gguf(sys.argv[1])
+        all_tensors, metadata = load_gguf(sys.argv[1])
         vocab = metadata["tokenizer.ggml.tokens"]
         for i, v in enumerate(vocab):
             vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -168,13 +169,13 @@ def main():
             'n_head_kv': metadata['llama.attention.head_count_kv'],
             'n_layer': metadata['llama.block_count'],
             'n_rot': metadata['llama.rope.dimension_count'],
-            'rope_freq': metadata['llama.rope.freq_base'],
+            'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
             'ftype': metadata['general.file_type'],
         }
         all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
     else:
-        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = load_ggml(sys.argv[1])
         print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
     print("model built, starting inference")
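
For context, a minimal standalone sketch of the updated loading path. It only assumes what the diff itself shows: load_gguf is importable from candle.utils and returns (tensors, metadata), and the 10000.0 default mirrors the metadata.get(...) fallback added above.

# Sketch based on this change; load_gguf's (tensors, metadata) return shape is
# taken from its use in quant-llama.py.
import sys
from candle.utils import load_gguf

filename = sys.argv[1]
all_tensors, metadata = load_gguf(filename)

# Older GGUF files may not carry a RoPE frequency base; fall back to 10000.0,
# matching the default introduced in this change.
rope_freq = metadata.get("llama.rope.freq_base", 10000.0)
print(f"loaded {len(all_tensors)} tensors, rope_freq={rope_freq}")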