author    Laurent Mazare <laurent.mazare@gmail.com>    2023-09-02 14:41:48 +0200
committer GitHub <noreply@github.com>    2023-09-02 13:41:48 +0100
commit    ad796eb4be9877712c0034d291a082cee1fd2dec (patch)
tree      8a7e9c5aa0a3d607c3352ceb47ea0cb3958aef28 /candle-pyo3/quant-llama.py
parent    e8e33752f4562d69cbef3de61d02676da112dfb8 (diff)
More quantized llama in python. (#716)

* More quantized llama in python.
* Expose a couple more functions.
* Apply the last layer.
* Use the vocab from the ggml files.
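
For context, the script takes a single weight-file argument and, with this change, streams the generated tokens to stdout. A hypothetical invocation from the repository root (the model filename is illustrative, not part of this commit):

    python candle-pyo3/quant-llama.py llama-2-7b.ggmlv3.q4_0.bin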
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--  candle-pyo3/quant-llama.py | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index a3638855..092c1faa 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
     idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
     idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
     m = idx_theta.matmul(theta.unsqueeze(0))
-    print(m.shape)
     return (m.cos(), m.sin())
 
 class QuantizedLlama:
@@ -143,28 +142,36 @@
         for layer in self.layers:
             x = layer(x, mask, index_pos)
+        x = self.norm(x)
+        x = x.narrow(1, -1, 1).squeeze(1)
+        x = self.output.matmul_t(x)
         return x
 
 def main():
     if len(sys.argv) < 2:
         raise ValueError("missing weight file argument")
     filename = sys.argv[1]
+    print(f"reading model file {filename}")
     if filename.endswith("gguf"):
         all_tensors = candle.load_gguf(sys.argv[1])
         hparams = None
+        vocab = None
     else:
-        all_tensors, hparams = candle.load_ggml(sys.argv[1])
+        all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
         print(hparams)
     model = QuantizedLlama(hparams, all_tensors)
+    print("model built, starting inference")
 
     tokens = [1]
-    for token_idx in range(1):
-        print(tokens)
+    for token_idx in range(500):
         last_token = tokens[-1]
         lt = candle.tensor([last_token]).unsqueeze(0)
         logits = model(lt, len(tokens))
-        print(logits)
-        next_token = "TODO: sample"
+        # Greedy sampling for now
+        # pr = candle.nn.softmax(logits, -1)
+        m = logits.get(0).argmax_keepdim(-1)
+        next_token = m.values()[0]
+        print(vocab[next_token], end='', flush=True)
         tokens.append(next_token)
 
 if __name__ == '__main__':
     main()
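
The generation loop above samples greedily (argmax over the last-position logits); the commented-out softmax line hints at probabilistic sampling. A minimal sketch of that swap, assuming the same (1, vocab_size) logits shape and using only calls already visible in this file plus the Python stdlib:

    import random

    # Inside the generation loop, replacing the argmax block:
    pr = candle.nn.softmax(logits, -1)  # turn logits into a probability distribution
    probs = pr.get(0).values()          # row 0 of the batch as a plain Python list
    next_token = random.choices(range(len(probs)), weights=probs, k=1)[0]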