summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs24
1 files changed, 16 insertions, 8 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 6bb3d740..8ad9322b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
+use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -776,8 +777,15 @@ impl Tensor {
/// comparison operation is specified by the `op` argument.
///
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
- pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
- let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
+ let rhs = match rhs.to_tensor_scalar()? {
+ crate::scalar::TensorScalar::Tensor(rhs) => rhs,
+ crate::scalar::TensorScalar::Scalar(rhs) => rhs
+ .to_dtype(self.dtype())?
+ .to_device(self.device())?
+ .broadcast_as(self.shape())?,
+ };
+ let shape = self.same_shape_binary_op(&rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
@@ -786,36 +794,36 @@ impl Tensor {
}
/// Element-wise equality.
- pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
/// Element-wise non-equality.
- pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
/// rhs` and 0 otherwise.
- pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
/// rhs` and 0 otherwise.
- pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
/// rhs` and 0 otherwise.
- pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
/// rhs` and 0 otherwise.
- pub fn le(&self, rhs: &Self) -> Result<Self> {
+ pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}