summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-04 21:32:14 +0200
committerGitHub <noreply@github.com>2023-09-04 20:32:14 +0100
commit000487c36fc6da3a8d645fffbd8c24e23bcdeed1 (patch)
treec7a862aa62b99dfb6a384149939ed44ca4e3b090 /candle-pyo3
parentab0d9fbdd1db6a586c0cd6ca9ee1a31203db3684 (diff)
downloadcandle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.tar.gz
candle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.tar.bz2
candle-000487c36fc6da3a8d645fffbd8c24e23bcdeed1.zip
Add a python function to save as safetensors. (#740)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/src/lib.rs14
1 files changed, 13 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index f71970d5..eddc0fda 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,5 +1,4 @@
#![allow(clippy::redundant_closure_call)]
-// TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyTuple};
@@ -715,6 +714,18 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
+fn save_safetensors(
+ path: &str,
+ tensors: std::collections::HashMap<String, PyTensor>,
+) -> PyResult<()> {
+ let tensors = tensors
+ .into_iter()
+ .map(|(s, t)| (s, t.0))
+ .collect::<std::collections::HashMap<_, _>>();
+ ::candle::safetensors::save(&tensors, path).map_err(wrap_err)
+}
+
+#[pyfunction]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@@ -867,6 +878,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
+ m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())