| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-02 14:41:48 +0200 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2023-09-02 13:41:48 +0100 |
| commit | ad796eb4be9877712c0034d291a082cee1fd2dec (patch) | |
| tree | 8a7e9c5aa0a3d607c3352ceb47ea0cb3958aef28 /candle-pyo3/quant-llama.py | |
| parent | e8e33752f4562d69cbef3de61d02676da112dfb8 (diff) | |
More quantized llama in python. (#716)
* More quantized llama in python.
* Expose a couple more functions.
* Apply the last layer.
* Use the vocab from the ggml files.
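
For context on the last two bullets: "applying the last layer" is the usual llama LM head, i.e. RMSNorm over the final hidden states, keeping only the last sequence position, and projecting onto the vocabulary; the greedy next token is then the argmax of those logits. Below is a minimal NumPy sketch of that computation (the function and argument names are illustrative, not the candle-pyo3 API):

```python
import numpy as np

def lm_head_greedy(hidden, norm_weight, output_weight, eps=1e-5):
    """Project the last hidden state onto the vocabulary, return the argmax id.

    hidden:        (seq_len, dim) activations out of the last transformer block
    norm_weight:   (dim,) RMSNorm scale
    output_weight: (vocab_size, dim) output projection matrix
    """
    last = hidden[-1]                          # keep only the final position
    rms = np.sqrt(np.mean(last * last) + eps)  # RMSNorm, as llama uses
    logits = output_weight @ (last / rms * norm_weight)
    return int(np.argmax(logits))              # greedy decoding

# Toy usage with random weights.
rng = np.random.default_rng(0)
dim, vocab_size = 8, 32
print(lm_head_greedy(rng.normal(size=(5, dim)), np.ones(dim),
                     rng.normal(size=(vocab_size, dim))))
```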
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r-- | candle-pyo3/quant-llama.py | 19 |
1 file changed, 13 insertions(+), 6 deletions(-)
```diff
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index a3638855..092c1faa 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
     idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
     idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
     m = idx_theta.matmul(theta.unsqueeze(0))
-    print(m.shape)
     return (m.cos(), m.sin())
 
 class QuantizedLlama:
@@ -143,28 +142,36 @@ class QuantizedLlama:
         for layer in self.layers:
             x = layer(x, mask, index_pos)
+        x = self.norm(x)
+        x = x.narrow(1, -1, 1).squeeze(1)
+        x = self.output.matmul_t(x)
         return x
 
 def main():
     if len(sys.argv) < 2:
         raise ValueError("missing weight file argument")
     filename = sys.argv[1]
+    print(f"reading model file {filename}")
     if filename.endswith("gguf"):
         all_tensors = candle.load_gguf(sys.argv[1])
         hparams = None
+        vocab = None
     else:
-        all_tensors, hparams = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
     print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
+    print("model built, starting inference")
 
     tokens = [1]
-    for token_idx in range(1):
-        print(tokens)
+    for token_idx in range(500):
         last_token = tokens[-1]
         lt = candle.tensor([last_token]).unsqueeze(0)
         logits = model(lt, len(tokens))
-        print(logits)
-        next_token = "TODO: sample"
+        # Greedy sampling for now
+        # pr = candle.nn.softmax(logits, -1)
+        m = logits.get(0).argmax_keepdim(-1)
+        next_token = m.values()[0]
+        print(vocab[next_token], end='', flush=True)
         tokens.append(next_token)
 
 if __name__ == '__main__':
     main()
```
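
The loop above decodes greedily: each step feeds only the newest token (a 1x1 tensor) plus the running position `len(tokens)`, takes the argmax of the logits, and prints the corresponding vocab entry. The commented-out `candle.nn.softmax` call marks where real sampling would slot in. A minimal NumPy sketch of that follow-up, temperature sampling over raw logits (illustrative only, not part of this commit):

```python
import numpy as np

def sample_token(logits, temperature=0.8, rng=None):
    """Sample a token id from raw logits using temperature scaling."""
    if rng is None:
        rng = np.random.default_rng()
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()          # shift for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()            # normalize into a distribution
    return int(rng.choice(len(probs), p=probs))

# Lower temperatures concentrate mass on the argmax; higher ones flatten it.
print(sample_token([1.0, 2.0, 0.5], temperature=0.7))
```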