diff options
Diffstat (limited to 'candle-pyo3/tests')
-rw-r--r-- | candle-pyo3/tests/native/test_tensor.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 659423e0..e4cf19f1 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -55,6 +55,7 @@ def test_tensor_can_be_sliced(): assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0] assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0] assert t[-4:-2].values() == [5.0, 9.0] + assert t[...].values() == t.values() def test_tensor_can_be_sliced_2d(): @@ -76,6 +77,43 @@ def test_tensor_can_be_scliced_3d(): assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] +def test_tensor_can_be_expanded_with_none(): + t = candle.rand((12, 12)) + + b = t[None] + assert b.shape == (1, 12, 12) + c = t[:, None, None, :] + assert c.shape == (12, 1, 1, 12) + d = t[None, :, None, :] + assert d.shape == (1, 12, 1, 12) + e = t[None, None, :, :] + assert e.shape == (1, 1, 12, 12) + f = t[:, :, None] + assert f.shape == (12, 12, 1) + + +def test_tensor_can_be_index_via_tensor(): + t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]]) + indexed = t[candle.Tensor([0, 2])] + assert indexed.shape == (2, 4) + assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]] + + indexed = t[:, candle.Tensor([0, 2])] + assert indexed.shape == (3, 2) + assert indexed.values() == [[1, 1], [3, 3], [5, 5]] + + +def test_tensor_can_be_index_via_list(): + t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]]) + indexed = t[[0, 2]] + assert indexed.shape == (2, 4) + assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]] + + indexed = t[:, [0, 2]] + assert indexed.shape == (3, 2) + assert indexed.values() == [[1, 1], [3, 3], [5, 5]] + + def test_tensor_can_be_cast_via_to(): t = Tensor(42.0) assert str(t.dtype) == str(candle.f32) |