summaryrefslogtreecommitdiff
path: root/candle-pyo3/tests/native/test_tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/tests/native/test_tensor.py')
-rw-r--r--candle-pyo3/tests/native/test_tensor.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index e4cf19f1..ef44fc4c 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,6 +1,7 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
+from candle.testing import assert_equal
import pytest
@@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+def assert_bool(t: Tensor, expected: bool):
+ assert t.shape == ()
+ assert str(t.dtype) == str(candle.u8)
+ assert bool(t.values()) == expected
+
+
+def test_tensor_supports_equality_opperations_with_scalars():
+ t = Tensor(42.0)
+
+ assert_bool(t == 42.0, True)
+ assert_bool(t == 43.0, False)
+
+ assert_bool(t != 42.0, False)
+ assert_bool(t != 43.0, True)
+
+ assert_bool(t > 41.0, True)
+ assert_bool(t > 42.0, False)
+
+ assert_bool(t >= 41.0, True)
+ assert_bool(t >= 42.0, True)
+
+ assert_bool(t < 43.0, True)
+ assert_bool(t < 42.0, False)
+
+ assert_bool(t <= 43.0, True)
+ assert_bool(t <= 42.0, True)
+
+
+def test_tensor_supports_equality_opperations_with_tensors():
+ t = Tensor(42.0)
+ same = Tensor(42.0)
+ other = Tensor(43.0)
+
+ assert_bool(t == same, True)
+ assert_bool(t == other, False)
+
+ assert_bool(t != same, False)
+ assert_bool(t != other, True)
+
+ assert_bool(t > same, False)
+ assert_bool(t > other, False)
+
+ assert_bool(t >= same, True)
+ assert_bool(t >= other, False)
+
+ assert_bool(t < same, False)
+ assert_bool(t < other, True)
+
+ assert_bool(t <= same, True)
+ assert_bool(t <= other, True)
+
+
+def test_tensor_equality_opperations_can_broadcast():
+ # Create a decoder attention mask as a test case
+ # e.g.
+ # [[1,0,0]
+ # [1,1,0]
+ # [1,1,1]]
+ mask_cond = candle.Tensor([0, 1, 2])
+ mask = mask_cond < (mask_cond + 1).reshape((3, 1))
+ assert mask.shape == (3, 3)
+ assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
+
+
+def test_tensor_can_be_hashed():
+ t = Tensor(42.0)
+ other = Tensor(42.0)
+ # Hash should represent a unique tensor
+ assert hash(t) != hash(other)
+ assert hash(t) == hash(t)
+
+
def test_tensor_can_be_expanded_with_none():
t = candle.rand((12, 12))