diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-01 21:27:35 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-01 21:27:35 +0100 |
commit | fbbde5b02cc4a711aed609f887d1705d67b2fd20 (patch) | |
tree | 4b3749fc2260d11562cdcda7766366603e5d8233 /candle-pyo3/src/lib.rs | |
parent | 42d1a52d01f5f10e3a04257cb4612225f08e1321 (diff) | |
download | candle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.tar.gz candle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.tar.bz2 candle-fbbde5b02cc4a711aed609f887d1705d67b2fd20.zip |
Add some binary operators.
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index e1ce7f97..35de86c8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,4 @@ -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use ::candle::{Device::Cpu, Tensor}; @@ -7,6 +7,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) } +#[derive(Clone)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -42,6 +43,21 @@ impl PyTensor { fn __str__(&self) -> String { self.__repr__() } + + fn __add__(&self, rhs: &PyAny) -> PyResult<Self> { + let tensor = if let Ok(rhs) = rhs.extract::<Self>() { + (&self.0 + &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::<f64>() { + (&self.0 + rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported for add"))? + }; + Ok(Self(tensor)) + } + + fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> { + self.__add__(rhs) + } } #[pyfunction] |