diff options
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r-- | candle-pyo3/quant-llama.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index a3638855..092c1faa 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base): idx_theta = [float(i) for i in range(MAX_SEQ_LEN)] idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1)) m = idx_theta.matmul(theta.unsqueeze(0)) - print(m.shape) return (m.cos(), m.sin()) class QuantizedLlama: @@ -143,28 +142,36 @@ class QuantizedLlama: for layer in self.layers: x = layer(x, mask, index_pos) + x = self.norm(x) + x = x.narrow(1, -1, 1).squeeze(1) + x = self.output.matmul_t(x) return x def main(): if len(sys.argv) < 2: raise ValueError("missing weight file argument") filename = sys.argv[1] + print(f"reading model file {filename}") if filename.endswith("gguf"): all_tensors = candle.load_gguf(sys.argv[1]) hparams = None + vocab = None else: - all_tensors, hparams = candle.load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1]) print(hparams) model = QuantizedLlama(hparams, all_tensors) + print("model built, starting inference") tokens = [1] - for token_idx in range(1): - print(tokens) + for token_idx in range(500): last_token = tokens[-1] lt = candle.tensor([last_token]).unsqueeze(0) logits = model(lt, len(tokens)) - print(logits) - next_token = "TODO: sample" + # Greedy sampling for now + # pr = candle.nn.softmax(logits, -1) + m = logits.get(0).argmax_keepdim(-1) + next_token = m.values()[0] + print(vocab[next_token], end='', flush=True) tokens.append(next_token) if __name__ == '__main__': |