Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--  candle-pyo3/quant-llama.py  31
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 020d525d..46d9ff62 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,27 +1,28 @@
 # This example shows how the candle Python api can be used to replicate llama.cpp.
 import sys
+from typing import Dict, Tuple, Any
 import candle
-from candle.utils import load_ggml,load_gguf
+from candle import Tensor, QTensor, utils, nn
 MAX_SEQ_LEN = 4096
-def masked_fill(on_false, mask, on_true):
+def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
     shape = mask.shape
     on_true = candle.tensor(on_true).broadcast_as(shape)
     return mask.where_cond(on_true, on_false)
 class RmsNorm:
-    def __init__(self, qtensor):
+    def __init__(self, qtensor:QTensor):
         self.weight = qtensor.dequantize()
-    def __call__(self, x):
+    def __call__(self, x:Tensor):
         b_size, seq_len, hidden_size = x.shape
         norm_x = x.sqr().sum_keepdim(2) / hidden_size
         x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
         return x_normed.broadcast_mul(self.weight)
 class QuantizedLayer:
-    def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
+    def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
         p = f"layers.{layer_idx}"
         self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
         self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
@@ -41,7 +42,7 @@ class QuantizedLayer:
         self.cos = cos_sin[0]
         self.sin = cos_sin[1]
-    def __call__(self, x, mask, index_pos):
+    def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
         residual = x
         x = self.attn_norm(x)
         attn = self.forward_attn(x, mask, index_pos)
@@ -51,11 +52,11 @@ class QuantizedLayer:
         x = self.ffn_norm(x)
         w1 = self.ffw1.matmul_t(x)
         w3 = self.ffw3.matmul_t(x)
-        mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
+        mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
         return mlp + residual
-    def forward_attn(self, x, mask, index_pos):
+    def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
         b_size, seq_len, n_embd = x.shape
         q = self.attention_wq.matmul_t(x)
         k = self.attention_wk.matmul_t(x)
@@ -80,12 +81,12 @@ class QuantizedLayer:
         att = q.matmul(k.t()) / self.head_dim**0.5
         mask = mask.broadcast_as(att.shape)
         att = masked_fill(att, mask, float("-inf"))
-        att = candle.nn.softmax(att, -1)
+        att = nn.softmax(att, -1)
         y = att.matmul(v.contiguous())
         y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
         return self.attention_wo.matmul_t(y)
-    def apply_rotary_emb(self, x, index_pos):
+    def apply_rotary_emb(self, x:Tensor, index_pos:int):
         (b_size, n_head, seq_len, n_embd) = x.shape
         cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
         sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
@@ -107,7 +108,7 @@ def precompute_freqs_cis(hparams, freq_base):
     return (m.cos(), m.sin())
 class QuantizedLlama:
-    def __init__(self, hparams, all_tensors):
+    def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
         self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
         self.norm = RmsNorm(all_tensors["norm.weight"])
         self.output = all_tensors["output.weight"]
@@ -118,7 +119,7 @@ class QuantizedLlama:
             layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
             self.layers.append(layer)
-    def __call__(self, token, index_pos):
+    def __call__(self, token:Tensor, index_pos:int):
         b_size, seq_len = token.shape
         vocab_size, hidden_size = self.tok_embeddings.shape
         token = token.reshape((b_size * seq_len,))
@@ -135,7 +136,7 @@ class QuantizedLlama:
         x = self.output.matmul_t(x)
         return x
-def gguf_rename(tensor_name):
+def gguf_rename(tensor_name:str):
     if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
     if tensor_name == 'output_norm.weight': return 'norm.weight'
     tensor_name = tensor_name.replace('blk.', 'layers.')
@@ -155,7 +156,7 @@ def main():
     filename = sys.argv[1]
     print(f"reading model file {filename}")
     if filename.endswith("gguf"):
-        all_tensors, metadata = load_gguf(sys.argv[1])
+        all_tensors, metadata = utils.load_gguf(sys.argv[1])
         vocab = metadata["tokenizer.ggml.tokens"]
         for i, v in enumerate(vocab):
             vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -175,7 +176,7 @@ def main():
         all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
     else:
-        all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
     print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
     print("model built, starting inference")