summaryrefslogtreecommitdiff
path: root/candle-pyo3/quant-llama.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/quant-llama.py')
-rw-r--r--  candle-pyo3/quant-llama.py  19
1 files changed, 13 insertions, 6 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index a3638855..092c1faa 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
m = idx_theta.matmul(theta.unsqueeze(0))
- print(m.shape)
return (m.cos(), m.sin())
class QuantizedLlama:
@@ -143,28 +142,36 @@ class QuantizedLlama:
for layer in self.layers:
x = layer(x, mask, index_pos)
+ x = self.norm(x)
+ x = x.narrow(1, -1, 1).squeeze(1)
+ x = self.output.matmul_t(x)
return x
def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
+    print(f"reading model file {filename}")
if filename.endswith("gguf"):
all_tensors = candle.load_gguf(sys.argv[1])
hparams = None
+ vocab = None
else:
- all_tensors, hparams = candle.load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
+ print("model built, starting inference")
tokens = [1]
- for token_idx in range(1):
- print(tokens)
+ for token_idx in range(500):
last_token = tokens[-1]
lt = candle.tensor([last_token]).unsqueeze(0)
logits = model(lt, len(tokens))
- print(logits)
- next_token = "TODO: sample"
+ # Greedy sampling for now
+ # pr = candle.nn.softmax(logits, -1)
+ m = logits.get(0).argmax_keepdim(-1)
+ next_token = m.values()[0]
+ print(vocab[next_token], end='', flush=True)
tokens.append(next_token)
if __name__ == '__main__':