summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/testing/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/py_src/candle/testing/__init__.py')
-rw-r--r--candle-pyo3/py_src/candle/testing/__init__.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py
new file mode 100644
index 00000000..240b635f
--- /dev/null
+++ b/candle-pyo3/py_src/candle/testing/__init__.py
@@ -0,0 +1,70 @@
+import candle
+from candle import Tensor
+
+
+_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
+
+
+def _assert_tensor_metadata(
+ actual: Tensor,
+ expected: Tensor,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ if check_device:
+ assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
+
+ if check_dtype:
+ assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
+
+ if check_layout:
+ assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
+
+ if check_stride:
+ assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
+
+
+def assert_equal(
+ actual: Tensor,
+ expected: Tensor,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ """
+ Asserts that two tensors are exact equals.
+ """
+ _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
+ assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
+
+
+def assert_almost_equal(
+ actual: Tensor,
+ expected: Tensor,
+ rtol=1e-05,
+ atol=1e-08,
+ check_device: bool = True,
+ check_dtype: bool = True,
+ check_layout: bool = True,
+ check_stride: bool = False,
+):
+ """
+ Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
+
+ Computes: |actual - expected| ≤ atol + rtol x |expected|
+ """
+ _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
+
+ # Secure against overflow of u32 and u8 tensors
+ if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
+ actual = actual.to(candle.i64)
+ expected = expected.to(candle.i64)
+
+ diff = (actual - expected).abs()
+
+ threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
+
+ assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great"