# This example shows how the candle Python api can be used to replicate llama.cpp.
import os
import sys

# The "import candle" statement below works if there is a "candle.so" file in sys.path.
# Here we check for shared libraries that can be used in the build directory.
BUILD_DIR = "./target/release-with-debug"
so_file = BUILD_DIR + "/candle.so"
if os.path.islink(so_file):
    # Drop a link left over from a previous run so that it can be re-created.
    os.remove(so_file)
for lib_file in ["libcandle.dylib", "libcandle.so"]:
    lib_file_ = BUILD_DIR + "/" + lib_file
    if os.path.isfile(lib_file_):
        # The link target is given relative to BUILD_DIR, where the link lives.
        os.symlink(lib_file, so_file)
        sys.path.insert(0, BUILD_DIR)
        break

import candle

MAX_SEQ_LEN = 4096


def masked_fill(on_false, mask, on_true):
    """Select the scalar `on_true` where `mask` is set and `on_false` elsewhere.

    `on_true` is a Python scalar: it is converted to a tensor and broadcast to
    the shape of `mask`. `on_false` is a tensor, presumably it has to have the
    same shape as `mask` for `where_cond` to accept it.
    """
    on_true = candle.tensor(on_true).broadcast_as(mask.shape)
    return mask.where_cond(on_true, on_false)


class RmsNorm:
    """Root-mean-square normalization over the last dimension of a rank 3 tensor."""

    def __init__(self, qtensor):
        # The scaling weights are stored quantized, dequantize them once here.
        self.weight = qtensor.dequantize()

    def __call__(self, x):
        # x is unpacked as (b_size, seq_len, hidden_size); any other rank
        # raises on this line.
        _, _, hidden_size = x.shape
        # Mean of the squared values along the hidden dimension.
        norm_x = x.sqr().sum_keepdim(2) / hidden_size
        # The 1e-5 epsilon keeps the division well defined for all-zero rows.
        x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
        return x_normed.broadcast_mul(self.weight)


class QuantizedLayer:
    """A single transformer block: self-attention followed by a gated MLP.

    The projection weights are kept as they come out of `all_tensors` and are
    applied through `matmul_t`; only the normalization weights get
    dequantized (by `RmsNorm`).
    """

    def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
        """Fetch the weights of layer `layer_idx` from `all_tensors`.

        `hparams` has to provide the "n_head" and "n_embd" entries, `cos_sin`
        is the `(cos, sin)` pair returned by `precompute_freqs_cis`.
        """
        p = f"layers.{layer_idx}"
        self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
        self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
        self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
        self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
        self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
        self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
        self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
        self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
        self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])

        self.n_head = hparams["n_head"]
        # Keys/values use as many heads as the queries (no MQA support yet).
        self.n_kv_head = self.n_head
        self.head_dim = hparams["n_embd"] // self.n_head

        # (k, v) pair from the previous call, see forward_attn.
        self.kv_cache = None
        self.cos, self.sin = cos_sin[0], cos_sin[1]

    def __call__(self, x, mask, index_pos):
        """Apply the block to `x` and return a tensor with the same shape."""
        # Attention sub-block with a residual connection.
        residual = x
        x = self.attn_norm(x)
        attn = self.forward_attn(x, mask, index_pos)
        x = attn + residual

        # Gated feed-forward sub-block with a residual connection.
        residual = x
        x = self.ffn_norm(x)
        w1 = self.ffw1.matmul_t(x)
        w3 = self.ffw3.matmul_t(x)
        mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
        return mlp + residual

    def forward_attn(self, x, mask, index_pos):
        """Masked multi-head self-attention over `x`.

        `x` is unpacked as (b_size, seq_len, n_embd). `index_pos` is the
        position of the first element of `x` in the sequence, it is used both
        for the rotary embedding and to decide whether the cached keys/values
        of the previous call get re-used.
        """
        b_size, seq_len, n_embd = x.shape
        q = self.attention_wq.matmul_t(x)
        k = self.attention_wk.matmul_t(x)
        v = self.attention_wv.matmul_t(x)

        # Split the heads: (b_size, seq_len, heads * head_dim)
        #               -> (b_size, heads, seq_len, head_dim)
        q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
        k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
        v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)

        q = self.apply_rotary_emb(q, index_pos)
        k = self.apply_rotary_emb(k, index_pos)

        # A call with index_pos == 0 starts a new sequence and so overwrites
        # the cache rather than extending it.
        if self.kv_cache is not None and index_pos > 0:
            prev_k, prev_v = self.kv_cache
            # Append the new keys/values along the sequence dimension.
            k = candle.cat([prev_k, k], 2).contiguous()
            v = candle.cat([prev_v, v], 2).contiguous()

        self.kv_cache = (k, v)

        # TODO: maybe repeat k/v here if we start supporting MQA.

        # Scaled dot-product attention, positions flagged in the mask are set
        # to -inf so that they get a zero weight after the softmax.
        att = q.matmul(k.t()) / self.head_dim**0.5
        mask = mask.broadcast_as(att.shape)
        att = masked_fill(att, mask, float("-inf"))
        att = candle.nn.softmax(att, -1)
        y = att.matmul(v.contiguous())
        # Merge the heads back: (b_size, heads, seq_len, head_dim)
        #                    -> (b_size, seq_len, n_embd)
        y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
        return self.attention_wo.matmul_t(y)

    def apply_rotary_emb(self, x, index_pos):
        """Rotate consecutive pairs of values of `x` by position dependent angles.

        `x` is unpacked as (b_size, n_head, seq_len, head_dim); the angles are
        taken from the precomputed cos/sin tables for the positions
        `index_pos` to `index_pos + seq_len - 1`.
        """
        (b_size, n_head, seq_len, head_dim) = x.shape
        half_dim = head_dim // 2
        cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, half_dim, 1))
        sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, half_dim, 1))
        # Group the last dimension in pairs, each pair is rotated as a 2d point.
        x = x.reshape((b_size, n_head, seq_len, half_dim, 2))
        x0 = x.narrow(-1, 0, 1)
        x1 = x.narrow(-1, 1, 1)
        y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
        y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
        rope = candle.cat([y0, y1], -1)
        # Back to (b_size, n_head, seq_len, head_dim).
        return rope.flatten_from(-2)


def precompute_freqs_cis(hparams, freq_base):
    """Build the cos/sin tables used by the rotary embeddings.

    Returns a `(cos, sin)` pair of tensors obtained from a matrix with one row
    per position below MAX_SEQ_LEN and one column per pair of head dimensions.
    """
    head_dim = hparams["n_embd"] // hparams["n_head"]
    theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
    theta = candle.tensor(theta)
    idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
    idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
    # Outer product: m[pos][i] = pos * theta[i].
    m = idx_theta.matmul(theta.unsqueeze(0))
    print(m.shape)
    return (m.cos(), m.sin())


class QuantizedLlama:
    """Llama model running on top of the quantized tensors of a weight file."""

    def __init__(self, hparams, all_tensors):
        """Build the model.

        `hparams` has to provide the "n_embd", "n_head" and "n_layer" entries,
        `all_tensors` maps the tensor names to the quantized tensors.
        """
        self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
        self.norm = RmsNorm(all_tensors["norm.weight"])
        self.output = all_tensors["output.weight"]
        cos_sin = precompute_freqs_cis(hparams, 10000.)
        self.layers = [
            QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
            for layer_idx in range(hparams["n_layer"])
        ]

    def __call__(self, token, index_pos):
        """Return the logits for each position of `token`.

        `token` is unpacked as (b_size, seq_len) and holds token ids,
        `index_pos` is the position of its first column in the sequence.
        """
        b_size, seq_len = token.shape
        _, hidden_size = self.tok_embeddings.shape
        token = token.reshape((b_size * seq_len,))
        x = self.tok_embeddings.index_select(token, 0)
        x = x.reshape((b_size, seq_len, hidden_size))

        # Causal mask: mask[i][j] is set when the key position j comes after
        # the query position i. The rows of the attention matrix index the
        # queries, so the row index has to be the outer loop here.
        mask = [int(j > i) for i in range(seq_len) for j in range(seq_len)]
        mask = candle.tensor(mask).reshape((seq_len, seq_len))

        for layer in self.layers:
            x = layer(x, mask, index_pos)
        # Final normalization and projection of the hidden states onto the
        # rows of the output weights.
        x = self.norm(x)
        return self.output.matmul_t(x)


def main():
    """Load the weight file given as first argument and run the model on one token.

    Raises a ValueError if the argument is missing or points at a gguf file.
    """
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")
    filename = sys.argv[1]
    if filename.endswith("gguf"):
        # No hyper-parameters are available on this path and the model cannot
        # be built without them, so fail before loading the weights rather
        # than with an obscure error later on.
        raise ValueError(
            "gguf files are not supported yet as their hyper-parameters are "
            "not available, please use a ggml file"
        )
    all_tensors, hparams = candle.load_ggml(filename)
    print(hparams)
    model = QuantizedLlama(hparams, all_tensors)

    tokens = [1]
    for _ in range(1):
        print(tokens)
        last_token = tokens[-1]
        lt = candle.tensor([last_token]).unsqueeze(0)
        logits = model(lt, len(tokens))
        print(logits)
        # Placeholder until sampling from the logits is implemented.
        next_token = "TODO: sample"
        tokens.append(next_token)


if __name__ == "__main__":
    main()