summaryrefslogtreecommitdiff
path: root/candle-pyo3/src
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-17 11:57:12 +0200
committerGitHub <noreply@github.com>2023-10-17 10:57:12 +0100
commitb355ab4e2e52b077e71aac46c286fbce033f36d6 (patch)
tree27f32bbcb5e0aa16ed14790bd3f5b37ae26ddf26 /candle-pyo3/src
parent2fe24ac5b172526c25b07674b38075f8da20815f (diff)
downloadcandle-b355ab4e2e52b077e71aac46c286fbce033f36d6.tar.gz
candle-b355ab4e2e52b077e71aac46c286fbce033f36d6.tar.bz2
candle-b355ab4e2e52b077e71aac46c286fbce033f36d6.zip
Always broadcast magic methods (#1101)
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--candle-pyo3/src/lib.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 02db05e5..55b20308 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -536,7 +536,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
- (&self.0 + &rhs.0).map_err(wrap_err)?
+ self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 + rhs).map_err(wrap_err)?
} else {
@@ -553,7 +553,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
- (&self.0 * &rhs.0).map_err(wrap_err)?
+ self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 * rhs).map_err(wrap_err)?
} else {
@@ -570,7 +570,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
- (&self.0 - &rhs.0).map_err(wrap_err)?
+ self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 - rhs).map_err(wrap_err)?
} else {
@@ -583,7 +583,7 @@ impl PyTensor {
/// &RETURNS&: Tensor
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
- (&self.0 / &rhs.0).map_err(wrap_err)?
+ self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 / rhs).map_err(wrap_err)?
} else {