summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-pyo3/src/lib.rs6
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7da91b3f..d5d472d5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -309,10 +309,16 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
Ok(PyTensor(tensor))
}
+#[pyfunction]
+fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
+ PyTensor::new(py, vs)
+}
+
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
+ m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
Ok(())
}