diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 79f86479..f71970d5 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -746,10 +746,35 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] -fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { +fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { + use ::candle::quantized::gguf_file; + fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> { + let v: PyObject = match v { + gguf_file::Value::U8(x) => x.into_py(py), + gguf_file::Value::I8(x) => x.into_py(py), + gguf_file::Value::U16(x) => x.into_py(py), + gguf_file::Value::I16(x) => x.into_py(py), + gguf_file::Value::U32(x) => x.into_py(py), + gguf_file::Value::I32(x) => x.into_py(py), + gguf_file::Value::U64(x) => x.into_py(py), + gguf_file::Value::I64(x) => x.into_py(py), + gguf_file::Value::F32(x) => x.into_py(py), + gguf_file::Value::F64(x) => x.into_py(py), + gguf_file::Value::Bool(x) => x.into_py(py), + gguf_file::Value::String(x) => x.into_py(py), + gguf_file::Value::Array(x) => { + let list = pyo3::types::PyList::empty(py); + for elem in x.iter() { + list.append(gguf_value_to_pyobject(elem, py)?)?; + } + list.into() + } + }; + Ok(v) + } let mut file = std::fs::File::open(path)?; - let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?; - let res = gguf + let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?; + let tensors = gguf .tensor_infos .keys() .map(|key| { @@ -758,7 +783,15 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> { }) .collect::<::candle::Result<Vec<_>>>() .map_err(wrap_err)?; - Ok(res.into_py_dict(py).to_object(py)) + let tensors = tensors.into_py_dict(py).to_object(py); + let metadata = gguf + .metadata + .iter() + .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?))) + .collect::<PyResult<Vec<_>>>()? + .into_py_dict(py) + .to_object(py); + Ok((tensors, metadata)) } #[pyfunction] |