diff options
Diffstat (limited to 'candle-pyo3/tests/bindings/test_testing.py')
-rw-r--r-- | candle-pyo3/tests/bindings/test_testing.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py new file mode 100644 index 00000000..db2fd3f7 --- /dev/null +++ b/candle-pyo3/tests/bindings/test_testing.py @@ -0,0 +1,33 @@ +import candle +from candle import Tensor +from candle.testing import assert_equal, assert_almost_equal +import pytest + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) +def test_assert_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_equal(a, b) + + with pytest.raises(AssertionError): + assert_equal(a, b + 1) + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) +def test_assert_almost_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_almost_equal(a, b) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1) + + assert_almost_equal(a, b + 1, atol=20) + assert_almost_equal(a, b + 1, rtol=20) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, atol=0.9) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, rtol=0.1) |