summaryrefslogtreecommitdiff
path: root/candle-pyo3/quant-llama.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-02 12:26:05 +0200
committerGitHub <noreply@github.com>2023-09-02 11:26:05 +0100
commite8e33752f4562d69cbef3de61d02676da112dfb8 (patch)
tree9b8b8587aa1bb60eef0ac3a44ef6f4a0f86916cc /candle-pyo3/quant-llama.py
parentdabaa479b966296faad294c40b69d321d51ee4df (diff)
downloadcandle-e8e33752f4562d69cbef3de61d02676da112dfb8.tar.gz
candle-e8e33752f4562d69cbef3de61d02676da112dfb8.tar.bz2
candle-e8e33752f4562d69cbef3de61d02676da112dfb8.zip
Sketch a quantized llama using the pyo3 api. (#715)
* Sketch a quantized llama using the pyo3 api. * Add more ops. * Expose a few more functions to use in the quantized model. * Rope embeddings. * Get the forward pass to work.
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--candle-pyo3/quant-llama.py171
1 files changed, 171 insertions, 0 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
new file mode 100644
index 00000000..a3638855
--- /dev/null
+++ b/candle-pyo3/quant-llama.py
@@ -0,0 +1,171 @@
+# This example shows how the candle Python api can be used to replicate llama.cpp.
+import os
+import sys
+
+# The "import candle" statement below works if there is a "candle.so" file in sys.path.
+# Here we check for shared libraries that can be used in the build directory.
+BUILD_DIR = "./target/release-with-debug"
+so_file = BUILD_DIR + "/candle.so"
+if os.path.islink(so_file): os.remove(so_file)
+for lib_file in ["libcandle.dylib", "libcandle.so"]:
+ lib_file_ = BUILD_DIR + "/" + lib_file
+ if os.path.isfile(lib_file_):
+ os.symlink(lib_file, so_file)
+ sys.path.insert(0, BUILD_DIR)
+ break
+
+import candle
+
+MAX_SEQ_LEN = 4096
+
+def masked_fill(on_false, mask, on_true):
+ shape = mask.shape
+ on_true = candle.tensor(on_true).broadcast_as(shape)
+ return mask.where_cond(on_true, on_false)
+
+class RmsNorm:
+ def __init__(self, qtensor):
+ self.weight = qtensor.dequantize()
+
+ def __call__(self, x):
+ b_size, seq_len, hidden_size = x.shape
+ norm_x = x.sqr().sum_keepdim(2) / hidden_size
+ x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
+ return x_normed.broadcast_mul(self.weight)
+
+class QuantizedLayer:
+ def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
+ p = f"layers.{layer_idx}"
+ self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
+ self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
+ self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
+ self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
+ self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
+ self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
+ self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
+ self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
+ self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
+
+ self.n_head = hparams["n_head"]
+ self.n_kv_head = self.n_head
+ self.head_dim = hparams["n_embd"] // self.n_head
+
+ self.kv_cache = None
+ self.cos = cos_sin[0]
+ self.sin = cos_sin[1]
+
+ def __call__(self, x, mask, index_pos):
+ residual = x
+ x = self.attn_norm(x)
+ attn = self.forward_attn(x, mask, index_pos)
+ x = attn + residual
+
+ residual = x
+ x = self.ffn_norm(x)
+ w1 = self.ffw1.matmul_t(x)
+ w3 = self.ffw3.matmul_t(x)
+ mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
+
+ return mlp + residual
+
+ def forward_attn(self, x, mask, index_pos):
+ b_size, seq_len, n_embd = x.shape
+ q = self.attention_wq.matmul_t(x)
+ k = self.attention_wk.matmul_t(x)
+ v = self.attention_wv.matmul_t(x)
+
+ q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
+ k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+ v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+
+ q = self.apply_rotary_emb(q, index_pos)
+ k = self.apply_rotary_emb(k, index_pos)
+
+ if self.kv_cache is not None and index_pos > 0:
+ prev_k, prev_v = self.kv_cache
+ k = candle.cat([prev_k, k], 2).contiguous()
+ v = candle.cat([prev_v, v], 2).contiguous()
+
+ self.kv_cache = (k, v)
+
+ # TODO: maybe repeat k/v here if we start supporting MQA.
+
+ att = q.matmul(k.t()) / self.head_dim**0.5
+ mask = mask.broadcast_as(att.shape)
+ att = masked_fill(att, mask, float("-inf"))
+ att = candle.nn.softmax(att, -1)
+ y = att.matmul(v.contiguous())
+ y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
+ return self.attention_wo.matmul_t(y)
+
+ def apply_rotary_emb(self, x, index_pos):
+ (b_size, n_head, seq_len, n_embd) = x.shape
+ cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
+ sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
+ x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
+ x0 = x.narrow(-1, 0, 1)
+ x1 = x.narrow(-1, 1, 1)
+ y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
+ y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
+ rope = candle.cat([y0, y1], -1)
+ return rope.flatten_from(-2)
+
+def precompute_freqs_cis(hparams, freq_base):
+ head_dim = hparams["n_embd"] // hparams["n_head"]
+ theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
+ theta = candle.tensor(theta)
+ idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
+ idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
+ m = idx_theta.matmul(theta.unsqueeze(0))
+ print(m.shape)
+ return (m.cos(), m.sin())
+
+class QuantizedLlama:
+ def __init__(self, hparams, all_tensors):
+ self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
+ self.norm = RmsNorm(all_tensors["norm.weight"])
+ self.output = all_tensors["output.weight"]
+ self.layers = []
+ cos_sin = precompute_freqs_cis(hparams, 10000.)
+ for layer_idx in range(hparams["n_layer"]):
+ layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
+ self.layers.append(layer)
+
+ def __call__(self, token, index_pos):
+ b_size, seq_len = token.shape
+ vocab_size, hidden_size = self.tok_embeddings.shape
+ token = token.reshape((b_size * seq_len,))
+ x = self.tok_embeddings.index_select(token, 0)
+ x = x.reshape((b_size, seq_len, hidden_size))
+
+ mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
+ mask = candle.tensor(mask).reshape((seq_len, seq_len))
+
+ for layer in self.layers:
+ x = layer(x, mask, index_pos)
+ return x
+
+def main():
+ if len(sys.argv) < 2:
+ raise ValueError("missing weight file argument")
+ filename = sys.argv[1]
+ if filename.endswith("gguf"):
+ all_tensors = candle.load_gguf(sys.argv[1])
+ hparams = None
+ else:
+ all_tensors, hparams = candle.load_ggml(sys.argv[1])
+ print(hparams)
+ model = QuantizedLlama(hparams, all_tensors)
+
+ tokens = [1]
+ for token_idx in range(1):
+ print(tokens)
+ last_token = tokens[-1]
+ lt = candle.tensor([last_token]).unsqueeze(0)
+ logits = model(lt, len(tokens))
+ print(logits)
+ next_token = "TODO: sample"
+ tokens.append(next_token)
+
+if __name__ == '__main__':
+ main()