From b355ab4e2e52b077e71aac46c286fbce033f36d6 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Tue, 17 Oct 2023 11:57:12 +0200 Subject: Always broadcast magic methods (#1101) --- candle-pyo3/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'candle-pyo3/src') diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 02db05e5..55b20308 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -536,7 +536,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 + &rhs.0).map_err(wrap_err)? + self.0.broadcast_add(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 + rhs).map_err(wrap_err)? } else { @@ -553,7 +553,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __mul__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 * &rhs.0).map_err(wrap_err)? + self.0.broadcast_mul(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 * rhs).map_err(wrap_err)? } else { @@ -570,7 +570,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __sub__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 - &rhs.0).map_err(wrap_err)? + self.0.broadcast_sub(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 - rhs).map_err(wrap_err)? } else { @@ -583,7 +583,7 @@ impl PyTensor { /// &RETURNS&: Tensor fn __truediv__(&self, rhs: &PyAny) -> PyResult { let tensor = if let Ok(rhs) = rhs.extract::() { - (&self.0 / &rhs.0).map_err(wrap_err)? + self.0.broadcast_div(&rhs.0).map_err(wrap_err)? } else if let Ok(rhs) = rhs.extract::() { (&self.0 / rhs).map_err(wrap_err)? } else { -- cgit v1.2.3