summaryrefslogtreecommitdiff
path: root/candle-pyo3/tests/bindings/test_testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/tests/bindings/test_testing.py')
-rw-r--r--candle-pyo3/tests/bindings/test_testing.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py
new file mode 100644
index 00000000..db2fd3f7
--- /dev/null
+++ b/candle-pyo3/tests/bindings/test_testing.py
@@ -0,0 +1,33 @@
+import candle
+from candle import Tensor
+from candle.testing import assert_equal, assert_almost_equal
+import pytest
+
+
+@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
+def test_assert_equal_asserts_correctly(dtype: candle.DType):
+ a = Tensor([1, 2, 3]).to(dtype)
+ b = Tensor([1, 2, 3]).to(dtype)
+ assert_equal(a, b)
+
+ with pytest.raises(AssertionError):
+ assert_equal(a, b + 1)
+
+
+@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
+def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
+ a = Tensor([1, 2, 3]).to(dtype)
+ b = Tensor([1, 2, 3]).to(dtype)
+ assert_almost_equal(a, b)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1)
+
+ assert_almost_equal(a, b + 1, atol=20)
+ assert_almost_equal(a, b + 1, rtol=20)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1, atol=0.9)
+
+ with pytest.raises(AssertionError):
+ assert_almost_equal(a, b + 1, rtol=0.1)