summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-02 20:15:50 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-02 20:15:50 +0100
commitbdb257ceabcbec1a3401b9f02c817d4f42c46a1e (patch)
tree6fb82508ce85011c8f1f68fb784ef496e5e35ad2
parent78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3 (diff)
downloadcandle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.tar.gz
candle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.tar.bz2
candle-bdb257ceabcbec1a3401b9f02c817d4f42c46a1e.zip
Add the tensor function.
-rw-r--r--candle-pyo3/src/lib.rs6
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7da91b3f..d5d472d5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -309,10 +309,16 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
Ok(PyTensor(tensor))
}
+#[pyfunction]
+fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
+ PyTensor::new(py, vs)
+}
+
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
+ m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
Ok(())
}