summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/testing/__init__.py
blob: 240b635f28140ea68cf803b34676a141c2c572f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import candle
from candle import Tensor


# String names of the unsigned integer dtypes. Subtracting two tensors of these
# dtypes can wrap around instead of going negative, so callers widen them first.
_UNSIGNED_DTYPES = {str(candle.u8), str(candle.u32)}


def _assert_tensor_metadata(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Compare the metadata of two tensors and raise on the first mismatch.

    Args:
        actual: Tensor under test.
        expected: Reference tensor.
        check_device: Compare the ``device`` attributes.
        check_dtype: Compare the ``dtype`` attributes (by their string form).
        check_layout: Compare the ``shape`` attributes.
        check_stride: Compare the ``stride`` attributes.

    Raises:
        AssertionError: If any enabled check fails.
    """
    # Raise explicitly rather than using `assert`, so these checks are not
    # stripped when the interpreter runs with `-O`. `not (a == b)` keeps the
    # exact equality semantics of the original `assert a == b`.
    if check_device and not (actual.device == expected.device):
        raise AssertionError(f"Device mismatch: {actual.device} != {expected.device}")

    if check_dtype and not (str(actual.dtype) == str(expected.dtype)):
        raise AssertionError(f"Dtype mismatch: {actual.dtype} != {expected.dtype}")

    if check_layout and not (actual.shape == expected.shape):
        raise AssertionError(f"Shape mismatch: {actual.shape} != {expected.shape}")

    if check_stride and not (actual.stride == expected.stride):
        raise AssertionError(f"Stride mismatch: {actual.stride} != {expected.stride}")


def assert_equal(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are exactly equal, element by element.

    Metadata (device, dtype, shape, stride) is checked first according to the
    ``check_*`` flags. Then every element of ``actual`` must compare equal to
    the corresponding element of ``expected``.

    Raises:
        AssertionError: If a metadata check fails or any element differs.
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

    # Compare element-wise instead of summing |actual - expected|:
    # - for unsigned dtypes the subtraction wraps, and wrapped differences can
    #   sum back to zero (a false pass);
    # - inf - inf is NaN, which made identical infinite entries look unequal.
    equal_mask = actual == expected
    # Widen the mask before summing so the match count cannot wrap in a narrow
    # integer accumulator on large tensors.
    n_equal = equal_mask.to(candle.i64).sum_all().values()
    if n_equal != equal_mask.nelement:
        raise AssertionError(f"Tensors mismatch: {actual} != {expected}")


def assert_almost_equal(
    actual: Tensor,
    expected: Tensor,
    rtol=1e-05,
    atol=1e-08,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are almost equal, using an element-wise comparison with a tolerance.

    Every element must satisfy: |actual - expected| <= atol + rtol * |expected|

    Args:
        actual: Tensor under test.
        expected: Reference tensor; the relative tolerance is scaled by its magnitude.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        check_device, check_dtype, check_layout, check_stride: Which metadata
            checks to run before comparing values.

    Raises:
        AssertionError: If a metadata check fails or any element is outside the tolerance.
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

    # Secure against overflow of u32 and u8 tensors: widen to a signed dtype so
    # the subtraction below cannot wrap around.
    if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
        actual = actual.to(candle.i64)
        expected = expected.to(candle.i64)

    diff = (actual - expected).abs()

    # The tolerance is computed in f32, then converted back to match `expected`.
    threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

    within_tolerance = diff <= threshold
    # Widen the mask before summing so the count cannot wrap in a narrow
    # integer accumulator on tensors with many elements.
    n_close = within_tolerance.to(candle.i64).sum_all().values()
    n_total = within_tolerance.nelement
    if n_close != n_total:
        raise AssertionError(
            f"Difference between tensors was too great: {n_total - n_close} of {n_total} elements "
            f"exceed atol={atol} + rtol={rtol} * |expected|"
        )