diff options
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r-- | candle-pyo3/quant-llama.py | 197 |
1 files changed, 38 insertions, 159 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 46d9ff62..1cb39e4f 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -2,181 +2,59 @@ import sys from typing import Dict, Tuple, Any import candle -from candle import Tensor, QTensor, utils, nn +from candle.models.llama import QuantizedLlama +from candle import utils MAX_SEQ_LEN = 4096 -def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor): - shape = mask.shape - on_true = candle.tensor(on_true).broadcast_as(shape) - return mask.where_cond(on_true, on_false) -class RmsNorm: - def __init__(self, qtensor:QTensor): - self.weight = qtensor.dequantize() - - def __call__(self, x:Tensor): - b_size, seq_len, hidden_size = x.shape - norm_x = x.sqr().sum_keepdim(2) / hidden_size - x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) - return x_normed.broadcast_mul(self.weight) - -class QuantizedLayer: - def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]): - p = f"layers.{layer_idx}" - self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] - self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] - self.attention_wv = all_tensors[f"{p}.attention.wv.weight"] - self.attention_wo = all_tensors[f"{p}.attention.wo.weight"] - self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"] - self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"] - self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"] - self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"]) - self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"]) - - self.n_head = hparams["n_head"] - self.n_kv_head = self.n_head - self.head_dim = hparams["n_embd"] // self.n_head - - self.kv_cache = None - self.cos = cos_sin[0] - self.sin = cos_sin[1] - - def __call__(self, x:Tensor, mask:Tensor, index_pos:int): - residual = x - x = self.attn_norm(x) - attn = self.forward_attn(x, mask, index_pos) - x = attn + residual - - residual = x - x = self.ffn_norm(x) - w1 = self.ffw1.matmul_t(x) - w3 = self.ffw3.matmul_t(x) - mlp = self.ffw2.matmul_t(nn.silu(w1) * w3) - - return mlp + residual - - def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int): - b_size, seq_len, n_embd = x.shape - q = self.attention_wq.matmul_t(x) - k = self.attention_wk.matmul_t(x) - v = self.attention_wv.matmul_t(x) - - q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2) - k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) - v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2) - - q = self.apply_rotary_emb(q, index_pos) - k = self.apply_rotary_emb(k, index_pos) - - if self.kv_cache is not None and index_pos > 0: - prev_k, prev_v = self.kv_cache - k = candle.cat([prev_k, k], 2).contiguous() - v = candle.cat([prev_v, v], 2).contiguous() - - self.kv_cache = (k, v) - - # TODO: maybe repeat k/v here if we start supporting MQA. - - att = q.matmul(k.t()) / self.head_dim**0.5 - mask = mask.broadcast_as(att.shape) - att = masked_fill(att, mask, float("-inf")) - att = nn.softmax(att, -1) - y = att.matmul(v.contiguous()) - y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) - return self.attention_wo.matmul_t(y) - - def apply_rotary_emb(self, x:Tensor, index_pos:int): - (b_size, n_head, seq_len, n_embd) = x.shape - cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) - sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) - x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2)) - x0 = x.narrow(-1, 0, 1) - x1 = x.narrow(-1, 1, 1) - y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin) - y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos) - rope = candle.cat([y0, y1], -1) - return rope.flatten_from(-2) - -def precompute_freqs_cis(hparams, freq_base): - head_dim = hparams["n_embd"] // hparams["n_head"] - theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)] - theta = candle.tensor(theta) - idx_theta = [float(i) for i in range(MAX_SEQ_LEN)] - idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1)) - m = idx_theta.matmul(theta.unsqueeze(0)) - return (m.cos(), m.sin()) - -class QuantizedLlama: - def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]): - self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() - self.norm = RmsNorm(all_tensors["norm.weight"]) - self.output = all_tensors["output.weight"] - self.layers = [] - rope_freq = hparams.get("rope_freq", 10000.) - cos_sin = precompute_freqs_cis(hparams, rope_freq) - for layer_idx in range(hparams["n_layer"]): - layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) - self.layers.append(layer) - - def __call__(self, token:Tensor, index_pos:int): - b_size, seq_len = token.shape - vocab_size, hidden_size = self.tok_embeddings.shape - token = token.reshape((b_size * seq_len,)) - x = self.tok_embeddings.index_select(token, 0) - x = x.reshape((b_size, seq_len, hidden_size)) - - mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)] - mask = candle.tensor(mask).reshape((seq_len, seq_len)) - - for layer in self.layers: - x = layer(x, mask, index_pos) - x = self.norm(x) - x = x.narrow(1, -1, 1).squeeze(1) - x = self.output.matmul_t(x) - return x - -def gguf_rename(tensor_name:str): - if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight' - if tensor_name == 'output_norm.weight': return 'norm.weight' - tensor_name = tensor_name.replace('blk.', 'layers.') - tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.') - tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.') - tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.') - tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.') - tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.') - tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.') - tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.') - tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.') +def gguf_rename(tensor_name: str): + if tensor_name == "token_embd.weight": + return "tok_embeddings.weight" + if tensor_name == "output_norm.weight": + return "norm.weight" + tensor_name = tensor_name.replace("blk.", "layers.") + tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.") + tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.") + tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.") + tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.") + tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.") + tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.") + tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.") + tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.") return tensor_name + def main(): if len(sys.argv) < 2: raise ValueError("missing weight file argument") + filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors, metadata = utils.load_gguf(sys.argv[1]) + all_tensors, metadata = utils.load_gguf(filename) vocab = metadata["tokenizer.ggml.tokens"] for i, v in enumerate(vocab): - vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') + vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ") hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")} print(hparams) hparams = { - 'n_vocab': len(vocab), - 'n_embd': metadata['llama.embedding_length'], - 'n_mult': 256, - 'n_head': metadata['llama.attention.head_count'], - 'n_head_kv': metadata['llama.attention.head_count_kv'], - 'n_layer': metadata['llama.block_count'], - 'n_rot': metadata['llama.rope.dimension_count'], - 'rope_freq': metadata.get('llama.rope.freq_base', 10000.), - 'ftype': metadata['general.file_type'], + "n_vocab": len(vocab), + "n_embd": metadata["llama.embedding_length"], + "n_mult": 256, + "n_head": metadata["llama.attention.head_count"], + "n_head_kv": metadata["llama.attention.head_count_kv"], + "n_layer": metadata["llama.block_count"], + "n_rot": metadata["llama.rope.dimension_count"], + "rope_freq": metadata.get("llama.rope.freq_base", 10000.0), + "ftype": metadata["general.file_type"], + "context_length": metadata["llama.context_length"], } - all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } - + all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()} else: - all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = utils.load_ggml(filename) + hparams["context_length"] = 2048 + print(hparams) model = QuantizedLlama(hparams, all_tensors) print("model built, starting inference") @@ -185,13 +63,14 @@ def main(): for token_idx in range(500): last_token = tokens[-1] lt = candle.tensor([last_token]).unsqueeze(0) - logits = model(lt, len(tokens)) + logits = model.forward(lt, len(tokens)) # Greedy sampling for now # pr = candle.nn.softmax(logits, -1) m = logits.get(0).argmax_keepdim(-1) next_token = m.values()[0] - print(vocab[next_token], end='', flush=True) + print(vocab[next_token], end="", flush=True) tokens.append(next_token) -if __name__ == '__main__': + +if __name__ == "__main__": main() |