summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-01 21:27:35 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-01 21:27:35 +0100
commitfbbde5b02cc4a711aed609f887d1705d67b2fd20 (patch)
tree4b3749fc2260d11562cdcda7766366603e5d8233 /candle-pyo3/src/lib.rs
parent42d1a52d01f5f10e3a04257cb4612225f08e1321 (diff)
downloadcandle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.tar.gz
candle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.tar.bz2
candle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.zip
Add some binary operators.
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs18
1 files changed, 17 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index e1ce7f97..35de86c8 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,4 +1,4 @@
-use pyo3::exceptions::PyValueError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use ::candle::{Device::Cpu, Tensor};
@@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
+#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
@@ -42,6 +43,21 @@ impl PyTensor {
fn __str__(&self) -> String {
self.__repr__()
}
+
+ fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
+ let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
+ (&self.0 + &rhs.0).map_err(wrap_err)?
+ } else if let Ok(rhs) = rhs.extract::<f64>() {
+ (&self.0 + rhs).map_err(wrap_err)?
+ } else {
+ Err(PyTypeError::new_err("unsupported for add"))?
+ };
+ Ok(Self(tensor))
+ }
+
+ fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
+ self.__add__(rhs)
+ }
}
#[pyfunction]