diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-04 21:32:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-04 20:32:14 +0100 |
commit | 000487c36fc6da3a8d645fffbd8c24e23bcdeed1 (patch) | |
tree | c7a862aa62b99dfb6a384149939ed44ca4e3b090 /candle-pyo3 | |
parent | ab0d9fbdd1db6a586c0cd6ca9ee1a31203db3684 (diff) | |
download | candle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.tar.gz candle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.tar.bz2 candle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.zip |
Add a python function to save as safetensors. (#740)
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/src/lib.rs | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index f71970d5..eddc0fda 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,5 +1,4 @@ #![allow(clippy::redundant_closure_call)] -// TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyTuple}; @@ -715,6 +714,18 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { } #[pyfunction] +fn save_safetensors( + path: &str, + tensors: std::collections::HashMap<String, PyTensor>, +) -> PyResult<()> { + let tensors = tensors + .into_iter() + .map(|(s, t)| (s, t.0)) + .collect::<std::collections::HashMap<_, _>>(); + ::candle::safetensors::save(&tensors, path).map_err(wrap_err) +} + +#[pyfunction] fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; @@ -867,6 +878,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(rand, m)?)?; m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(tensor, m)?)?; + m.add_function(wrap_pyfunction!(save_safetensors, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; m.add_function(wrap_pyfunction!(zeros, m)?)?; Ok(()) |