summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs41
1 files changed, 37 insertions, 4 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 79f86479..f71970d5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -746,10 +746,35 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
+fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+ use ::candle::quantized::gguf_file;
+ fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
+ let v: PyObject = match v {
+ gguf_file::Value::U8(x) => x.into_py(py),
+ gguf_file::Value::I8(x) => x.into_py(py),
+ gguf_file::Value::U16(x) => x.into_py(py),
+ gguf_file::Value::I16(x) => x.into_py(py),
+ gguf_file::Value::U32(x) => x.into_py(py),
+ gguf_file::Value::I32(x) => x.into_py(py),
+ gguf_file::Value::U64(x) => x.into_py(py),
+ gguf_file::Value::I64(x) => x.into_py(py),
+ gguf_file::Value::F32(x) => x.into_py(py),
+ gguf_file::Value::F64(x) => x.into_py(py),
+ gguf_file::Value::Bool(x) => x.into_py(py),
+ gguf_file::Value::String(x) => x.into_py(py),
+ gguf_file::Value::Array(x) => {
+ let list = pyo3::types::PyList::empty(py);
+ for elem in x.iter() {
+ list.append(gguf_value_to_pyobject(elem, py)?)?;
+ }
+ list.into()
+ }
+ };
+ Ok(v)
+ }
let mut file = std::fs::File::open(path)?;
- let gguf = ::candle::quantized::gguf_file::Content::read(&mut file).map_err(wrap_err)?;
- let res = gguf
+ let gguf = gguf_file::Content::read(&mut file).map_err(wrap_err)?;
+ let tensors = gguf
.tensor_infos
.keys()
.map(|key| {
@@ -758,7 +783,15 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<PyObject> {
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
- Ok(res.into_py_dict(py).to_object(py))
+ let tensors = tensors.into_py_dict(py).to_object(py);
+ let metadata = gguf
+ .metadata
+ .iter()
+ .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
+ .collect::<PyResult<Vec<_>>>()?
+ .into_py_dict(py)
+ .to_object(py);
+ Ok((tensors, metadata))
}
#[pyfunction]