diff options
-rw-r--r-- | candle-pyo3/src/lib.rs | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7da91b3f..d5d472d5 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -309,10 +309,16 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { Ok(PyTensor(tensor)) } +#[pyfunction] +fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> { + PyTensor::new(py, vs) +} + #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::<PyTensor>()?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(tensor, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; Ok(()) } |