diff options
Diffstat (limited to 'candle-pyo3/py_src/candle/testing/__init__.py')
-rw-r--r-- | candle-pyo3/py_src/candle/testing/__init__.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py new file mode 100644 index 00000000..240b635f --- /dev/null +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -0,0 +1,70 @@ +import candle +from candle import Tensor + + +_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)]) + + +def _assert_tensor_metadata( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + if check_device: + assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}" + + if check_dtype: + assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}" + + if check_layout: + assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}" + + if check_stride: + assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}" + + +def assert_equal( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts that two tensors are exact equals. + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}" + + +def assert_almost_equal( + actual: Tensor, + expected: Tensor, + rtol=1e-05, + atol=1e-08, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance. + + Computes: |actual - expected| ≤ atol + rtol x |expected| + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + + # Secure against overflow of u32 and u8 tensors + if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES: + actual = actual.to(candle.i64) + expected = expected.to(candle.i64) + + diff = (actual - expected).abs() + + threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) + + assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great" |