diff options
Diffstat (limited to 'candle-pyo3/tests/native/test_tensor.py')
-rw-r--r-- | candle-pyo3/tests/native/test_tensor.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index e4cf19f1..ef44fc4c 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -1,6 +1,7 @@ import candle from candle import Tensor from candle.utils import cuda_is_available +from candle.testing import assert_equal import pytest @@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d(): assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] +def assert_bool(t: Tensor, expected: bool): + assert t.shape == () + assert str(t.dtype) == str(candle.u8) + assert bool(t.values()) == expected + + +def test_tensor_supports_equality_opperations_with_scalars(): + t = Tensor(42.0) + + assert_bool(t == 42.0, True) + assert_bool(t == 43.0, False) + + assert_bool(t != 42.0, False) + assert_bool(t != 43.0, True) + + assert_bool(t > 41.0, True) + assert_bool(t > 42.0, False) + + assert_bool(t >= 41.0, True) + assert_bool(t >= 42.0, True) + + assert_bool(t < 43.0, True) + assert_bool(t < 42.0, False) + + assert_bool(t <= 43.0, True) + assert_bool(t <= 42.0, True) + + +def test_tensor_supports_equality_opperations_with_tensors(): + t = Tensor(42.0) + same = Tensor(42.0) + other = Tensor(43.0) + + assert_bool(t == same, True) + assert_bool(t == other, False) + + assert_bool(t != same, False) + assert_bool(t != other, True) + + assert_bool(t > same, False) + assert_bool(t > other, False) + + assert_bool(t >= same, True) + assert_bool(t >= other, False) + + assert_bool(t < same, False) + assert_bool(t < other, True) + + assert_bool(t <= same, True) + assert_bool(t <= other, True) + + +def test_tensor_equality_opperations_can_broadcast(): + # Create a decoder attention mask as a test case + # e.g. + # [[1,0,0] + # [1,1,0] + # [1,1,1]] + mask_cond = candle.Tensor([0, 1, 2]) + mask = mask_cond < (mask_cond + 1).reshape((3, 1)) + assert mask.shape == (3, 3) + assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8)) + + +def test_tensor_can_be_hashed(): + t = Tensor(42.0) + other = Tensor(42.0) + # Hash should represent a unique tensor + assert hash(t) != hash(other) + assert hash(t) == hash(t) + + def test_tensor_can_be_expanded_with_none(): t = candle.rand((12, 12)) |