diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-02 20:15:50 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-02 20:15:50 +0100 |
commit | bdb257ceabcbec1a3401b9f02c817d4f42c46a1e (patch) | |
tree | 6fb82508ce85011c8f1f68fb784ef496e5e35ad2 | |
parent | 78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3 (diff) | |
download | candle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.tar.gz candle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.tar.bz2 candle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.zip |
Add the tensor function.
-rw-r--r-- | candle-pyo3/src/lib.rs | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7da91b3f..d5d472d5 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -309,10 +309,16 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { Ok(PyTensor(tensor)) } +#[pyfunction] +fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> { + PyTensor::new(py, vs) +} + #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::<PyTensor>()?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(tensor, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; Ok(()) } |