summaryrefslogtreecommitdiff
path: root/candle-pyo3/tests
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-20 10:59:00 +0200
committerGitHub <noreply@github.com>2023-10-20 09:59:00 +0100
commitb43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289 (patch)
tree3750b13e038f9ec3a6da604e2ee07224259d1c6b /candle-pyo3/tests
parent31ca4897bbff517156f7730b9562ac30061b39d5 (diff)
downloadcandle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.tar.gz
candle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.tar.bz2
candle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.zip
PyO3: Add `None` and `Tensor` indexing to `candle.Tensor` (#1098)
* Add proper `None` and `tensor` indexing * Allow indexing via lists + allow tensor/list indexing outside of first dimension
Diffstat (limited to 'candle-pyo3/tests')
-rw-r--r--candle-pyo3/tests/native/test_tensor.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 659423e0..e4cf19f1 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -55,6 +55,7 @@ def test_tensor_can_be_sliced():
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[-4:-2].values() == [5.0, 9.0]
+ assert t[...].values() == t.values()
def test_tensor_can_be_sliced_2d():
@@ -76,6 +77,43 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+def test_tensor_can_be_expanded_with_none():
+ t = candle.rand((12, 12))
+
+ b = t[None]
+ assert b.shape == (1, 12, 12)
+ c = t[:, None, None, :]
+ assert c.shape == (12, 1, 1, 12)
+ d = t[None, :, None, :]
+ assert d.shape == (1, 12, 1, 12)
+ e = t[None, None, :, :]
+ assert e.shape == (1, 1, 12, 12)
+ f = t[:, :, None]
+ assert f.shape == (12, 12, 1)
+
+
+def test_tensor_can_be_index_via_tensor():
+ t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
+ indexed = t[candle.Tensor([0, 2])]
+ assert indexed.shape == (2, 4)
+ assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
+
+ indexed = t[:, candle.Tensor([0, 2])]
+ assert indexed.shape == (3, 2)
+ assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
+
+
+def test_tensor_can_be_index_via_list():
+ t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
+ indexed = t[[0, 2]]
+ assert indexed.shape == (2, 4)
+ assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
+
+ indexed = t[:, [0, 2]]
+ assert indexed.shape == (3, 2)
+ assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
+
+
def test_tensor_can_be_cast_via_to():
t = Tensor(42.0)
assert str(t.dtype) == str(candle.f32)