summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-02 12:26:05 +0200
committerGitHub <noreply@github.com>2023-09-02 11:26:05 +0100
commite8e33752f4562d69cbef3de61d02676da112dfb8 (patch)
tree9b8b8587aa1bb60eef0ac3a44ef6f4a0f86916cc /candle-pyo3
parentdabaa479b966296faad294c40b69d321d51ee4df (diff)
downloadcandle-e8e33752f4562d69cbef3de61d02676da112dfb8.tar.gz
candle-e8e33752f4562d69cbef3de61d02676da112dfb8.tar.bz2
candle-e8e33752f4562d69cbef3de61d02676da112dfb8.zip
Sketch a quantized llama using the pyo3 api. (#715)
* Sketch a quantized llama using the pyo3 api. * Add more ops. * Expose a few more functions to use in the quantized model. * Rope embeddings. * Get the forward pass to work.
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml1
-rw-r--r--candle-pyo3/quant-llama.py171
-rw-r--r--candle-pyo3/src/lib.rs111
3 files changed, 277 insertions, 6 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 60272c9b..97631b0a 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -16,6 +16,7 @@ doc = false
[dependencies]
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.2.1" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
new file mode 100644
index 00000000..a3638855
--- /dev/null
+++ b/candle-pyo3/quant-llama.py
@@ -0,0 +1,171 @@
+# This example shows how the candle Python api can be used to replicate llama.cpp.
+import os
+import sys
+
+# The "import candle" statement below works if there is a "candle.so" file in sys.path.
+# Here we check for shared libraries that can be used in the build directory.
+BUILD_DIR = "./target/release-with-debug"
+so_file = BUILD_DIR + "/candle.so"
+if os.path.islink(so_file): os.remove(so_file)
+for lib_file in ["libcandle.dylib", "libcandle.so"]:
+ lib_file_ = BUILD_DIR + "/" + lib_file
+ if os.path.isfile(lib_file_):
+ os.symlink(lib_file, so_file)
+ sys.path.insert(0, BUILD_DIR)
+ break
+
+import candle
+
+MAX_SEQ_LEN = 4096
+
+def masked_fill(on_false, mask, on_true):
+ shape = mask.shape
+ on_true = candle.tensor(on_true).broadcast_as(shape)
+ return mask.where_cond(on_true, on_false)
+
+class RmsNorm:
+ def __init__(self, qtensor):
+ self.weight = qtensor.dequantize()
+
+ def __call__(self, x):
+ b_size, seq_len, hidden_size = x.shape
+ norm_x = x.sqr().sum_keepdim(2) / hidden_size
+ x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
+ return x_normed.broadcast_mul(self.weight)
+
+class QuantizedLayer:
+ def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
+ p = f"layers.{layer_idx}"
+ self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
+ self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
+ self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
+ self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
+ self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
+ self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
+ self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
+ self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
+ self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
+
+ self.n_head = hparams["n_head"]
+ self.n_kv_head = self.n_head
+ self.head_dim = hparams["n_embd"] // self.n_head
+
+ self.kv_cache = None
+ self.cos = cos_sin[0]
+ self.sin = cos_sin[1]
+
+ def __call__(self, x, mask, index_pos):
+ residual = x
+ x = self.attn_norm(x)
+ attn = self.forward_attn(x, mask, index_pos)
+ x = attn + residual
+
+ residual = x
+ x = self.ffn_norm(x)
+ w1 = self.ffw1.matmul_t(x)
+ w3 = self.ffw3.matmul_t(x)
+ mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
+
+ return mlp + residual
+
+ def forward_attn(self, x, mask, index_pos):
+ b_size, seq_len, n_embd = x.shape
+ q = self.attention_wq.matmul_t(x)
+ k = self.attention_wk.matmul_t(x)
+ v = self.attention_wv.matmul_t(x)
+
+ q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
+ k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+ v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+
+ q = self.apply_rotary_emb(q, index_pos)
+ k = self.apply_rotary_emb(k, index_pos)
+
+ if self.kv_cache is not None and index_pos > 0:
+ prev_k, prev_v = self.kv_cache
+ k = candle.cat([prev_k, k], 2).contiguous()
+ v = candle.cat([prev_v, v], 2).contiguous()
+
+ self.kv_cache = (k, v)
+
+ # TODO: maybe repeat k/v here if we start supporting MQA.
+
+ att = q.matmul(k.t()) / self.head_dim**0.5
+ mask = mask.broadcast_as(att.shape)
+ att = masked_fill(att, mask, float("-inf"))
+ att = candle.nn.softmax(att, -1)
+ y = att.matmul(v.contiguous())
+ y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
+ return self.attention_wo.matmul_t(y)
+
+ def apply_rotary_emb(self, x, index_pos):
+ (b_size, n_head, seq_len, n_embd) = x.shape
+ cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
+ sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
+ x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
+ x0 = x.narrow(-1, 0, 1)
+ x1 = x.narrow(-1, 1, 1)
+ y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
+ y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
+ rope = candle.cat([y0, y1], -1)
+ return rope.flatten_from(-2)
+
+def precompute_freqs_cis(hparams, freq_base):
+ head_dim = hparams["n_embd"] // hparams["n_head"]
+ theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
+ theta = candle.tensor(theta)
+ idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
+ idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
+ m = idx_theta.matmul(theta.unsqueeze(0))
+ print(m.shape)
+ return (m.cos(), m.sin())
+
+class QuantizedLlama:
+ def __init__(self, hparams, all_tensors):
+ self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
+ self.norm = RmsNorm(all_tensors["norm.weight"])
+ self.output = all_tensors["output.weight"]
+ self.layers = []
+ cos_sin = precompute_freqs_cis(hparams, 10000.)
+ for layer_idx in range(hparams["n_layer"]):
+ layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
+ self.layers.append(layer)
+
+ def __call__(self, token, index_pos):
+ b_size, seq_len = token.shape
+ vocab_size, hidden_size = self.tok_embeddings.shape
+ token = token.reshape((b_size * seq_len,))
+ x = self.tok_embeddings.index_select(token, 0)
+ x = x.reshape((b_size, seq_len, hidden_size))
+
+ mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
+ mask = candle.tensor(mask).reshape((seq_len, seq_len))
+
+ for layer in self.layers:
+ x = layer(x, mask, index_pos)
+ return x
+
+def main():
+ if len(sys.argv) < 2:
+ raise ValueError("missing weight file argument")
+ filename = sys.argv[1]
+ if filename.endswith("gguf"):
+ all_tensors = candle.load_gguf(sys.argv[1])
+ hparams = None
+ else:
+ all_tensors, hparams = candle.load_ggml(sys.argv[1])
+ print(hparams)
+ model = QuantizedLlama(hparams, all_tensors)
+
+ tokens = [1]
+ for token_idx in range(1):
+ print(tokens)
+ last_token = tokens[-1]
+ lt = candle.tensor([last_token]).unsqueeze(0)
+ logits = model(lt, len(tokens))
+ print(logits)
+ next_token = "TODO: sample"
+ tokens.append(next_token)
+
+if __name__ == '__main__':
+ main()
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 2673d843..43d99c25 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);
+fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
+ let rank = t.rank();
+ if 0 <= dim {
+ let dim = dim as usize;
+ if rank <= dim {
+ ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}")
+ }
+ Ok(dim)
+ } else {
+ if (rank as i64) < -dim {
+ ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}")
+ }
+ Ok((rank as i64 + dim) as usize)
+ }
+}
+
// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
type Output;
@@ -182,7 +198,10 @@ impl PyTensor {
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else {
- Err(PyTypeError::new_err("incorrect type for tensor"))?
+ let ty = vs.as_ref(py).get_type();
+ Err(PyTypeError::new_err(format!(
+ "incorrect type {ty} for tensor"
+ )))?
};
Ok(Self(tensor))
}
@@ -295,10 +314,31 @@ impl PyTensor {
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
}
+ fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
+ }
+
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
+ fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
+ Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
+ }
+
+ fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
+ Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
+ }
+
+ fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
+ Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
+ }
+
+ fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
+ Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
+ }
+
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
Ok(PyTensor(
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
@@ -346,6 +386,17 @@ impl PyTensor {
Ok(Self(tensor))
}
+ fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
+ let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
+ (&self.0 / &rhs.0).map_err(wrap_err)?
+ } else if let Ok(rhs) = rhs.extract::<f64>() {
+ (&self.0 / rhs).map_err(wrap_err)?
+ } else {
+ Err(PyTypeError::new_err("unsupported rhs for div"))?
+ };
+ Ok(Self(tensor))
+ }
+
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
@@ -374,7 +425,8 @@ impl PyTensor {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
- fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
+ fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
@@ -400,6 +452,16 @@ impl PyTensor {
Ok(PyTensor(mean))
}
+ fn flatten_from(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
+ }
+
+ fn flatten_to(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
+ }
+
fn flatten_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
}
@@ -467,7 +529,11 @@ impl PyTensor {
/// Concatenate the tensors across one axis.
#[pyfunction]
-fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
+ if tensors.is_empty() {
+ return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
+ }
+ let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?;
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@@ -595,16 +661,27 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
- let res = ggml
+ let tensors = ggml
.tensors
.into_iter()
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
- Ok(res.into_py_dict(py).to_object(py))
+ let tensors = tensors.into_py_dict(py).to_object(py);
+ let hparams = [
+ ("n_vocab", ggml.hparams.n_vocab),
+ ("n_embd", ggml.hparams.n_embd),
+ ("n_mult", ggml.hparams.n_mult),
+ ("n_head", ggml.hparams.n_head),
+ ("n_layer", ggml.hparams.n_layer),
+ ("n_rot", ggml.hparams.n_rot),
+ ("ftype", ggml.hparams.ftype),
+ ];
+ let hparams = hparams.into_py_dict(py).to_object(py);
+ Ok((tensors, hparams))
}
#[pyfunction]
@@ -651,11 +728,33 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
+#[pyfunction]
+fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
+ let dim = actual_dim(&t, dim).map_err(wrap_err)?;
+ let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
+ Ok(PyTensor(sm))
+}
+
+#[pyfunction]
+fn silu(t: PyTensor) -> PyResult<PyTensor> {
+ let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
+ Ok(PyTensor(s))
+}
+
+fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+ m.add_function(wrap_pyfunction!(silu, m)?)?;
+ m.add_function(wrap_pyfunction!(softmax, m)?)?;
+ Ok(())
+}
+
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
+ let nn = PyModule::new(py, "nn")?;
+ candle_nn_m(py, nn)?;
+ m.add_submodule(nn)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;