diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-04 08:07:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-04 07:07:00 +0100 |
commit | 20512ba408f9840828e902b7dd824be5a0969feb (patch) | |
tree | d525df2e607b60e130a32bd8bcf62a01fb4be47f /candle-pyo3 | |
parent | 9c61b0fc9b9062b347c176b5f0f86b97b6804a1b (diff) | |
download | candle-20512ba408f9840828e902b7dd824be5a0969feb.tar.gz candle-20512ba408f9840828e902b7dd824be5a0969feb.tar.bz2 candle-20512ba408f9840828e902b7dd824be5a0969feb.zip |
Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings.
* Read the metadata in the quantized llama example.
* Get inference to work on gguf files.
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/quant-llama.py | 39 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 41 |
2 files changed, 72 insertions, 8 deletions
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 7d74c25e..0f7a51c6 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -111,7 +111,8 @@ class QuantizedLlama: self.norm = RmsNorm(all_tensors["norm.weight"]) self.output = all_tensors["output.weight"] self.layers = [] - cos_sin = precompute_freqs_cis(hparams, 10000.) + rope_freq = hparams.get("rope_freq", 10000.) + cos_sin = precompute_freqs_cis(hparams, rope_freq) for layer_idx in range(hparams["n_layer"]): layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) self.layers.append(layer) @@ -133,15 +134,45 @@ class QuantizedLlama: x = self.output.matmul_t(x) return x +def gguf_rename(tensor_name): + if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight' + if tensor_name == 'output_norm.weight': return 'norm.weight' + tensor_name = tensor_name.replace('blk.', 'layers.') + tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.') + tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.') + tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.') + tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.') + tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.') + tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.') + tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.') + tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.') + return tensor_name + def main(): if len(sys.argv) < 2: raise ValueError("missing weight file argument") filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors = candle.load_gguf(sys.argv[1]) - hparams = None - vocab = None + all_tensors, metadata = candle.load_gguf(sys.argv[1]) + vocab = metadata["tokenizer.ggml.tokens"] + for i, v in enumerate(vocab): + vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') + hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")} + print(hparams) + hparams = { + 'n_vocab': len(vocab), + 'n_embd': metadata['llama.embedding_length'], + 'n_mult': 256, + 'n_head': metadata['llama.attention.head_count'], + 'n_head_kv': metadata['llama.attention.head_count_kv'], + 'n_layer': metadata['llama.block_count'], + 'n_rot': metadata['llama.rope.dimension_count'], + 'rope_freq': metadata['llama.rope.freq_base'], + 'ftype': metadata['general.file_type'], + } + all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } + else: all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1]) print(hparams) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 79f86479..f71970d5 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -746,10 +746,35 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] -fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { +fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { + use ::candle::quantized::gguf_file; + fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> { + let v: PyObject = match v { + gguf_file::Value::U8(x) => x.into_py(py), + gguf_file::Value::I8(x) => x.into_py(py), + gguf_file::Value::U16(x) => x.into_py(py), + gguf_file::Value::I16(x) => x.into_py(py), + gguf_file::Value::U32(x) => x.into_py(py), + gguf_file::Value::I32(x) => x.into_py(py), + gguf_file::Value::U64(x) => x.into_py(py), + gguf_file::Value::I64(x) => x.into_py(py), + gguf_file::Value::F32(x) => x.into_py(py), + gguf_file::Value::F64(x) => x.into_py(py), + gguf_file::Value::Bool(x) => x.into_py(py), + gguf_file::Value::String(x) => x.into_py(py), + gguf_file::Value::Array(x) => { + let list = pyo3::types::PyList::empty(py); + for elem in x.iter() { + list.append(gguf_value_to_pyobject(elem, py)?)?; + } + list.into() + } + }; + Ok(v) + } let mut file = std::fs::File::open(path)?; - let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?; - let res = gguf + let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?; + let tensors = gguf .tensor_infos .keys() .map(|key| { @@ -758,7 +783,15 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { }) .collect::<::candle::Result<Vec<_>>>() .map_err(wrap_err)?; - Ok(res.into_py_dict(py).to_object(py)) + let tensors = tensors.into_py_dict(py).to_object(py); + let metadata = gguf + .metadata + .iter() + .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?))) + .collect::<PyResult<Vec<_>>>()? + .into_py_dict(py) + .to_object(py); + Ok((tensors, metadata)) } #[pyfunction] |