summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-30 16:17:28 +0100
committerGitHub <noreply@github.com>2023-10-30 15:17:28 +0000
commitc05c0a8213eb5518901b2aa87503e8c6b65b9d0f (patch)
tree238c8ce92c7651d2d4e9e712dc2bb6cc1ab27c2b /candle-pyo3
parent969960847ac7fd4959e8718d1355abb1f9f4385d (diff)
downloadcandle-c05c0a8213eb5518901b2aa87503e8c6b65b9d0f.tar.gz
candle-c05c0a8213eb5518901b2aa87503e8c6b65b9d0f.tar.bz2
candle-c05c0a8213eb5518901b2aa87503e8c6b65b9d0f.zip
PyO3: Add `equal` and `__richcmp__` to `candle.Tensor` (#1099)
* add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs`
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/_additional_typing/__init__.py36
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi41
-rw-r--r--candle-pyo3/py_src/candle/testing/__init__.py70
-rw-r--r--candle-pyo3/src/lib.rs73
-rw-r--r--candle-pyo3/tests/bindings/test_testing.py33
-rw-r--r--candle-pyo3/tests/native/test_tensor.py73
6 files changed, 324 insertions, 2 deletions
diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py
index 0d0eec90..7bc17ee1 100644
--- a/candle-pyo3/_additional_typing/__init__.py
+++ b/candle-pyo3/_additional_typing/__init__.py
@@ -53,3 +53,39 @@ class Tensor:
Return a slice of a tensor.
"""
pass
+
+ def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 35b17680..37b8fe8c 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
+ def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
+ def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
+ def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
@@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
+ def abs(self) -> Tensor:
+ """
+ Performs the `abs` operation on the tensor.
+ """
+ pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
@@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
+ @property
+ def nelement(self) -> int:
+ """
+ Gets the tensor's element count.
+ """
+ pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py
new file mode 100644
index 00000000..240b635f
--- /dev/null
+++ b/candle-pyo3/py_src/candle/testing/__init__.py
@@ -0,0 +1,70 @@
+import candle
+from candle import Tensor
+
+
+_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
+
+
+def _assert_tensor_metadata(
+ actual: Tensor,
+ expected: Tensor,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ if check_device:
+ assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
+
+ if check_dtype:
+ assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
+
+ if check_layout:
+ assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
+
+ if check_stride:
+ assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
+
+
+def assert_equal(
+ actual: Tensor,
+ expected: Tensor,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ """
+ Asserts that two tensors are exact equals.
+ """
+ _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
+ assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
+
+
+def assert_almost_equal(
+ actual: Tensor,
+ expected: Tensor,
+ rtol=1e-05,
+ atol=1e-08,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ """
+ Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
+
+ Computes: |actual - expected| ≤ atol + rtol x |expected|
+ """
+ _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
+
+ # Secure against overflow of u32 and u8 tensors
+ if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
+ actual = actual.to(candle.i64)
+ expected = expected.to(candle.i64)
+
+ diff = (actual - expected).abs()
+
+ threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
+
+ assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 41c4577f..ddd58fbe 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
+use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
+use std::collections::hash_map::DefaultHasher;
+use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
@@ -132,9 +135,10 @@ macro_rules! pydtype {
}
};
}
+
+pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
-pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
@@ -318,6 +322,13 @@ impl PyTensor {
}
#[getter]
+ /// Gets the tensor's element count.
+ /// &RETURNS&: int
+ fn nelement(&self) -> usize {
+ self.0.elem_count()
+ }
+
+ #[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
fn stride(&self, py: Python<'_>) -> PyObject {
@@ -353,6 +364,12 @@ impl PyTensor {
self.__repr__()
}
+ /// Performs the `abs` operation on the tensor.
+ /// &RETURNS&: Tensor
+ fn abs(&self) -> PyResult<Self> {
+ Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
+ }
+
/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
@@ -670,6 +687,58 @@ impl PyTensor {
};
Ok(Self(tensor))
}
+ /// Rich-compare two tensors.
+ /// &RETURNS&: Tensor
+ fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
+ let compare = |lhs: &Tensor, rhs: &Tensor| {
+ let t = match op {
+ CompareOp::Eq => lhs.eq(rhs),
+ CompareOp::Ne => lhs.ne(rhs),
+ CompareOp::Lt => lhs.lt(rhs),
+ CompareOp::Le => lhs.le(rhs),
+ CompareOp::Gt => lhs.gt(rhs),
+ CompareOp::Ge => lhs.ge(rhs),
+ };
+ Ok(PyTensor(t.map_err(wrap_err)?))
+ };
+ if let Ok(rhs) = rhs.extract::<PyTensor>() {
+ if self.0.shape() == rhs.0.shape() {
+ compare(&self.0, &rhs.0)
+ } else {
+ // We broadcast manually here because `candle.cmp` does not support automatic broadcasting
+ let broadcast_shape = self
+ .0
+ .shape()
+ .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
+ .map_err(wrap_err)?;
+ let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+ let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+
+ compare(&broadcasted_lhs, &broadcasted_rhs)
+ }
+ } else if let Ok(rhs) = rhs.extract::<f64>() {
+ let scalar_tensor = Tensor::new(rhs, self.0.device())
+ .map_err(wrap_err)?
+ .to_dtype(self.0.dtype())
+ .map_err(wrap_err)?
+ .broadcast_as(self.0.shape())
+ .map_err(wrap_err)?;
+
+ compare(&self.0, &scalar_tensor)
+ } else {
+ return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
+ }
+ }
+
+ fn __hash__(&self) -> u64 {
+ // we have overridden __richcmp__ => py03 wants us to also override __hash__
+ // we simply hash the address of the tensor
+ let mut hasher = DefaultHasher::new();
+ let pointer = &self.0 as *const Tensor;
+ let address = pointer as usize;
+ address.hash(&mut hasher);
+ hasher.finish()
+ }
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
@@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
- m.add("i16", PyDType(DType::I64))?;
+ m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;
diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py
new file mode 100644
index 00000000..db2fd3f7
--- /dev/null
+++ b/candle-pyo3/tests/bindings/test_testing.py
@@ -0,0 +1,33 @@
+import candle
+from candle import Tensor
+from candle.testing import assert_equal, assert_almost_equal
+import pytest
+
+
+@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
+def test_assert_equal_asserts_correctly(dtype: candle.DType):
+ a = Tensor([1, 2, 3]).to(dtype)
+ b = Tensor([1, 2, 3]).to(dtype)
+ assert_equal(a, b)
+
+ with pytest.raises(AssertionError):
+ assert_equal(a, b + 1)
+
+
+@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
+def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
+ a = Tensor([1, 2, 3]).to(dtype)
+ b = Tensor([1, 2, 3]).to(dtype)
+ assert_almost_equal(a, b)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1)
+
+ assert_almost_equal(a, b + 1, atol=20)
+ assert_almost_equal(a, b + 1, rtol=20)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1, atol=0.9)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1, rtol=0.1)
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index e4cf19f1..ef44fc4c 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,6 +1,7 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
+from candle.testing import assert_equal
import pytest
@@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+def assert_bool(t: Tensor, expected: bool):
+ assert t.shape == ()
+ assert str(t.dtype) == str(candle.u8)
+ assert bool(t.values()) == expected
+
+
+def test_tensor_supports_equality_opperations_with_scalars():
+ t = Tensor(42.0)
+
+ assert_bool(t == 42.0, True)
+ assert_bool(t == 43.0, False)
+
+ assert_bool(t != 42.0, False)
+ assert_bool(t != 43.0, True)
+
+ assert_bool(t > 41.0, True)
+ assert_bool(t > 42.0, False)
+
+ assert_bool(t >= 41.0, True)
+ assert_bool(t >= 42.0, True)
+
+ assert_bool(t < 43.0, True)
+ assert_bool(t < 42.0, False)
+
+ assert_bool(t <= 43.0, True)
+ assert_bool(t <= 42.0, True)
+
+
+def test_tensor_supports_equality_opperations_with_tensors():
+ t = Tensor(42.0)
+ same = Tensor(42.0)
+ other = Tensor(43.0)
+
+ assert_bool(t == same, True)
+ assert_bool(t == other, False)
+
+ assert_bool(t != same, False)
+ assert_bool(t != other, True)
+
+ assert_bool(t > same, False)
+ assert_bool(t > other, False)
+
+ assert_bool(t >= same, True)
+ assert_bool(t >= other, False)
+
+ assert_bool(t < same, False)
+ assert_bool(t < other, True)
+
+ assert_bool(t <= same, True)
+ assert_bool(t <= other, True)
+
+
+def test_tensor_equality_opperations_can_broadcast():
+ # Create a decoder attention mask as a test case
+ # e.g.
+ # [[1,0,0]
+ # [1,1,0]
+ # [1,1,1]]
+ mask_cond = candle.Tensor([0, 1, 2])
+ mask = mask_cond < (mask_cond + 1).reshape((3, 1))
+ assert mask.shape == (3, 3)
+ assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
+
+
+def test_tensor_can_be_hashed():
+ t = Tensor(42.0)
+ other = Tensor(42.0)
+ # Hash should represent a unique tensor
+ assert hash(t) != hash(other)
+ assert hash(t) == hash(t)
+
+
def test_tensor_can_be_expanded_with_none():
t = candle.rand((12, 12))