# This example shows how the candle Python API can be used to replicate llama.cpp.
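# Usage: python quant-llama.py <weights-file>  (a GGUF file, or a legacy GGML file otherwise)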
import sys
from typing import Dict, Tuple, Any
import candle
from candle.models.llama import QuantizedLlama
from candle import utils

MAX_SEQ_LEN = 4096


def gguf_rename(tensor_name: str) -> str:
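    """Map GGUF tensor names to the naming scheme used by QuantizedLlama."""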
    if tensor_name == "token_embd.weight":
        return "tok_embeddings.weight"
    if tensor_name == "output_norm.weight":
        return "norm.weight"
    tensor_name = tensor_name.replace("blk.", "layers.")
    tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.")
    tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.")
    tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.")
    tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.")
    tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.")
    tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.")
    tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.")
    tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.")
    return tensor_name


def main():
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")

    filename = sys.argv[1]
    print(f"reading model file {filename}")
    if filename.endswith("gguf"):
        all_tensors, metadata = utils.load_gguf(filename)
        vocab = metadata["tokenizer.ggml.tokens"]
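        # Normalize the SentencePiece vocab: "<0x0A>" is the newline byte token and "▁" marks a leading space.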
        for i, v in enumerate(vocab):
            vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
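        # Dump the non-tokenizer metadata for inspection.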
        hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
        print(hparams)
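        # Build the hyperparameter dict in the form QuantizedLlama expects.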
        hparams = {
            "n_vocab": len(vocab),
            "n_embd": metadata["llama.embedding_length"],
            "n_mult": 256,
            "n_head": metadata["llama.attention.head_count"],
            "n_head_kv": metadata["llama.attention.head_count_kv"],
            "n_layer": metadata["llama.block_count"],
            "n_rot": metadata["llama.rope.dimension_count"],
            "rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
            "ftype": metadata["general.file_type"],
            "context_length": metadata["llama.context_length"],
        }
        all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
    else:
        all_tensors, hparams, vocab = utils.load_ggml(filename)
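        # The legacy GGML header does not carry a context length, so fall back to 2048.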
        hparams["context_length"] = 2048

    print(hparams)
    model = QuantizedLlama(hparams, all_tensors)
    print("model built, starting inference")

    tokens = [1]  # start generation from the BOS token (id 1)
    for token_idx in range(500):
        last_token = tokens[-1]
        lt = candle.tensor([last_token]).unsqueeze(0)
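        # len(tokens) is the current position in the sequence, used by the model as the KV-cache / rotary-embedding offset.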
        logits = model.forward(lt, len(tokens))
        # Greedy sampling for now
        # pr = candle.nn.softmax(logits, -1)
        m = logits.get(0).argmax_keepdim(-1)
        next_token = m.values()[0]
        print(vocab[next_token], end="", flush=True)
        tokens.append(next_token)


if __name__ == "__main__":
    main()