summaryrefslogtreecommitdiff
path: root/candle-pyo3/quant-llama.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-04 08:07:00 +0200
committerGitHub <noreply@github.com>2023-09-04 07:07:00 +0100
commit20512ba408f9840828e902b7dd824be5a0969feb (patch)
treed525df2e607b60e130a32bd8bcf62a01fb4be47f /candle-pyo3/quant-llama.py
parent9c61b0fc9b9062b347c176b5f0f86b97b6804a1b (diff)
downloadcandle-20512ba408f9840828e902b7dd824be5a0969feb.tar.gz
candle-20512ba408f9840828e902b7dd824be5a0969feb.tar.bz2
candle-20512ba408f9840828e902b7dd824be5a0969feb.zip
Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings. * Read the metadata in the quantized llama example. * Get inference to work on gguf files.
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--candle-pyo3/quant-llama.py39
1 files changed, 35 insertions, 4 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 7d74c25e..0f7a51c6 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -111,7 +111,8 @@ class QuantizedLlama:
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
self.layers = []
- cos_sin = precompute_freqs_cis(hparams, 10000.)
+ rope_freq = hparams.get("rope_freq", 10000.)
+ cos_sin = precompute_freqs_cis(hparams, rope_freq)
for layer_idx in range(hparams["n_layer"]):
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
@@ -133,15 +134,45 @@ class QuantizedLlama:
x = self.output.matmul_t(x)
return x
+def gguf_rename(tensor_name):
+ if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
+ if tensor_name == 'output_norm.weight': return 'norm.weight'
+ tensor_name = tensor_name.replace('blk.', 'layers.')
+ tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
+ tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
+ tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
+ tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
+ tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
+ tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
+ tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
+ tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
+ return tensor_name
+
def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
- all_tensors = candle.load_gguf(sys.argv[1])
- hparams = None
- vocab = None
+ all_tensors, metadata = candle.load_gguf(sys.argv[1])
+ vocab = metadata["tokenizer.ggml.tokens"]
+ for i, v in enumerate(vocab):
+ vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
+ hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
+ print(hparams)
+ hparams = {
+ 'n_vocab': len(vocab),
+ 'n_embd': metadata['llama.embedding_length'],
+ 'n_mult': 256,
+ 'n_head': metadata['llama.attention.head_count'],
+ 'n_head_kv': metadata['llama.attention.head_count_kv'],
+ 'n_layer': metadata['llama.block_count'],
+ 'n_rot': metadata['llama.rope.dimension_count'],
+ 'rope_freq': metadata['llama.rope.freq_base'],
+ 'ftype': metadata['general.file_type'],
+ }
+ all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
+
else:
all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
print(hparams)