summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-02 14:41:48 +0200
committerGitHub <noreply@github.com>2023-09-02 13:41:48 +0100
commitad796eb4be9877712c0034d291a082cee1fd2dec (patch)
tree8a7e9c5aa0a3d607c3352ceb47ea0cb3958aef28 /candle-pyo3
parente8e33752f4562d69cbef3de61d02676da112dfb8 (diff)
downloadcandle-ad796eb4be9877712c0034d291a082cee1fd2dec.tar.gz
candle-ad796eb4be9877712c0034d291a082cee1fd2dec.tar.bz2
candle-ad796eb4be9877712c0034d291a082cee1fd2dec.zip
More quantized llama in python. (#716)
* More quantized llama in python. * Expose a couple more functions. * Apply the last layer. * Use the vocab from the ggml files.
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/quant-llama.py19
-rw-r--r--candle-pyo3/src/lib.rs56
2 files changed, 64 insertions, 11 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index a3638855..092c1faa 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -117,7 +117,6 @@ def precompute_freqs_cis(hparams, freq_base):
idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
m = idx_theta.matmul(theta.unsqueeze(0))
- print(m.shape)
return (m.cos(), m.sin())
class QuantizedLlama:
@@ -143,28 +142,36 @@ class QuantizedLlama:
for layer in self.layers:
x = layer(x, mask, index_pos)
+ x = self.norm(x)
+ x = x.narrow(1, -1, 1).squeeze(1)
+ x = self.output.matmul_t(x)
return x
def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
+ print(f"reading model file {filename}")
if filename.endswith("gguf"):
all_tensors = candle.load_gguf(sys.argv[1])
hparams = None
+ vocab = None
else:
- all_tensors, hparams = candle.load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
+ print("model built, starting inference")
tokens = [1]
- for token_idx in range(1):
- print(tokens)
+ for token_idx in range(500):
last_token = tokens[-1]
lt = candle.tensor([last_token]).unsqueeze(0)
logits = model(lt, len(tokens))
- print(logits)
- next_token = "TODO: sample"
+ # Greedy sampling for now
+ # pr = candle.nn.softmax(logits, -1)
+ m = logits.get(0).argmax_keepdim(-1)
+ next_token = m.values()[0]
+ print(vocab[next_token], end='', flush=True)
tokens.append(next_token)
if __name__ == '__main__':
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 43d99c25..5e6f48ea 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);
+fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> {
+ let dim = t.dim(dim)?;
+ if 0 <= index {
+ let index = index as usize;
+ if dim <= index {
+ ::candle::bail!("index {index} is too large for tensor dimension {dim}")
+ }
+ Ok(index)
+ } else {
+ if (dim as i64) < -index {
+ ::candle::bail!("index {index} is too low for tensor dimension {dim}")
+ }
+ Ok((dim as i64 + index) as usize)
+ }
+}
+
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
let rank = t.rank();
if 0 <= dim {
@@ -409,7 +425,8 @@ impl PyTensor {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
- fn squeeze(&self, dim: usize) -> PyResult<Self> {
+ fn squeeze(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
}
@@ -417,7 +434,8 @@ impl PyTensor {
Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
}
- fn get(&self, index: usize) -> PyResult<Self> {
+ fn get(&self, index: i64) -> PyResult<Self> {
+ let index = actual_index(self, 0, index).map_err(wrap_err)?;
Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
}
@@ -425,11 +443,32 @@ impl PyTensor {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
- fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
+ fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ let start = actual_index(self, dim, start).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
+ fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
+ }
+
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
vec![dim]
@@ -661,7 +700,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let tensors = ggml
@@ -681,7 +720,14 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
("ftype", ggml.hparams.ftype),
];
let hparams = hparams.into_py_dict(py).to_object(py);
- Ok((tensors, hparams))
+ let vocab = ggml
+ .vocab
+ .token_score_pairs
+ .iter()
+ .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string())
+ .collect::<Vec<String>>()
+ .to_object(py);
+ Ok((tensors, hparams, vocab))
}
#[pyfunction]