diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6bb3d740..8ad9322b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{ BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp, }; +use crate::scalar::TensorOrScalar; use crate::shape::{Dim, Dims}; use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -776,8 +777,15 @@ impl Tensor { /// comparison operation is specified by the `op` argument. /// /// The returned tensor has the same shape as the original tensors and uses `u8` elements. - pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> { - let shape = self.same_shape_binary_op(rhs, "cmp")?; + pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, "cmp")?; let storage = self .storage() .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?; @@ -786,36 +794,36 @@ impl Tensor { } /// Element-wise equality. - pub fn eq(&self, rhs: &Self) -> Result<Self> { + pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Eq) } /// Element-wise non-equality. - pub fn ne(&self, rhs: &Self) -> Result<Self> { + pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ne) } /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self < /// rhs` and 0 otherwise. - pub fn lt(&self, rhs: &Self) -> Result<Self> { + pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Lt) } /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self > /// rhs` and 0 otherwise. - pub fn gt(&self, rhs: &Self) -> Result<Self> { + pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Gt) } /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >= /// rhs` and 0 otherwise. - pub fn ge(&self, rhs: &Self) -> Result<Self> { + pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Ge) } /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <= /// rhs` and 0 otherwise. - pub fn le(&self, rhs: &Self) -> Result<Self> { + pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> { self.cmp(rhs, CmpOp::Le) } |