From 000487c36fc6da3a8d645fffbd8c24e23bcdeed1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Mon, 4 Sep 2023 21:32:14 +0200
Subject: Add a python function to save as safetensors. (#740)

---
 candle-pyo3/src/lib.rs | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

(limited to 'candle-pyo3')

diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index f71970d5..eddc0fda 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,5 +1,4 @@
 #![allow(clippy::redundant_closure_call)]
-// TODO: Handle negative dimension indexes.
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyTuple};
@@ -714,6 +713,18 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
     Ok(res.into_py_dict(py).to_object(py))
 }
 
+#[pyfunction]
+fn save_safetensors(
+    path: &str,
+    tensors: std::collections::HashMap<String, PyTensor>,
+) -> PyResult<()> {
+    let tensors = tensors
+        .into_iter()
+        .map(|(s, t)| (s, t.0))
+        .collect::<std::collections::HashMap<_, _>>();
+    ::candle::safetensors::save(&tensors, path).map_err(wrap_err)
+}
+
 #[pyfunction]
 fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
     let mut file = std::fs::File::open(path)?;
@@ -867,6 +878,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(rand, m)?)?;
     m.add_function(wrap_pyfunction!(randn, m)?)?;
     m.add_function(wrap_pyfunction!(tensor, m)?)?;
+    m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
     m.add_function(wrap_pyfunction!(stack, m)?)?;
     m.add_function(wrap_pyfunction!(zeros, m)?)?;
     Ok(())
-- 
cgit v1.2.3