summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-20 10:59:00 +0200
committerGitHub <noreply@github.com>2023-10-20 09:59:00 +0100
commitb43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289 (patch)
tree3750b13e038f9ec3a6da604e2ee07224259d1c6b /candle-pyo3
parent31ca4897bbff517156f7730b9562ac30061b39d5 (diff)
downloadcandle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.tar.gz
candle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.tar.bz2
candle-b43ab6cd1d7b128f2f9d7d8d3acc3a29c9d3b289.zip
PyO3: Add `None` and `Tensor` indexing to `candle.Tensor` (#1098)
* Add proper `None` and `tensor` indexing * Allow indexing via lists + allow tensor/list indexing outside of first dimension
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/src/lib.rs126
-rw-r--r--candle-pyo3/tests/native/test_tensor.py38
2 files changed, 132 insertions, 32 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index f9fdc712..f16d8c1b 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -201,6 +201,8 @@ enum Indexer {
Index(usize),
Slice(usize, usize),
Elipsis,
+ Expand,
+ IndexSelect(Tensor),
}
#[pymethods]
@@ -450,7 +452,7 @@ impl PyTensor {
let mut indexers: Vec<Indexer> = vec![];
let dims = self.0.shape().dims();
- let to_absolute_index = |index: isize, current_dim: usize| {
+ fn to_absolute_index(index: isize, current_dim: usize, dims: &[usize]) -> PyResult<usize> {
// Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0]
let actual_index = if index < 0 {
dims[current_dim] as isize + index
@@ -460,48 +462,92 @@ impl PyTensor {
// Check that the index is in range
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
- return Err(PyTypeError::new_err(format!(
+ return Err(PyValueError::new_err(format!(
"index out of range for dimension '{i}' with indexer '{value}'",
i = current_dim,
value = index
)));
}
Ok(actual_index as usize)
- };
- if let Ok(index) = idx.extract(py) {
- // Handle a single index e.g. tensor[0] or tensor[-1]
- indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
- } else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
- // Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
- let index = slice.indices(dims[0] as c_long)?;
- indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
- } else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
- // Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
-
- if tuple.len() > dims.len() {
- return Err(PyTypeError::new_err("provided too many indices"));
+ }
+
+ fn extract_indexer(
+ py_indexer: &PyAny,
+ current_dim: usize,
+ dims: &[usize],
+ index_argument_count: usize,
+ ) -> PyResult<(Indexer, usize)> {
+ if let Ok(index) = py_indexer.extract() {
+ // Handle a single index e.g. tensor[0] or tensor[-1]
+ Ok((
+ Indexer::Index(to_absolute_index(index, current_dim, dims)?),
+ current_dim + 1,
+ ))
+ } else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
+ // Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
+ let index = slice.indices(dims[current_dim] as c_long)?;
+ Ok((
+ Indexer::Slice(index.start as usize, index.stop as usize),
+ current_dim + 1,
+ ))
+ } else if let Ok(tensor) = py_indexer.extract::<PyTensor>() {
+ // Handle a tensor as indices e.g. tensor[tensor([0,1])]
+ let t = tensor.0;
+ if t.rank() != 1 {
+ return Err(PyTypeError::new_err(
+ "multi-dimensional tensor indexing is not supported",
+ ));
+ }
+ Ok((Indexer::IndexSelect(t), current_dim + 1))
+ } else if let Ok(list) = py_indexer.downcast::<pyo3::types::PyList>() {
+ // Handle a list of indices e.g. tensor[[0,1]]
+ let mut indexes = vec![];
+ for item in list.iter() {
+ let index = item.extract::<i64>()?;
+ indexes.push(index);
+ }
+ Ok((
+ Indexer::IndexSelect(
+ Tensor::from_vec(indexes, list.len(), &Device::Cpu).map_err(wrap_err)?,
+ ),
+ current_dim + 1,
+ ))
+ } else if py_indexer.is_ellipsis() {
+ // Handle '...' e.g. tensor[..., 0]
+ if current_dim > 0 {
+ return Err(PyTypeError::new_err(
+ "Ellipsis ('...') can only be used at the start of an indexing operation",
+ ));
+ }
+ Ok((Indexer::Elipsis, dims.len() - (index_argument_count - 1)))
+ } else if py_indexer.is_none() {
+ // Handle None e.g. tensor[None, 0]
+ Ok((Indexer::Expand, current_dim))
+ } else {
+ Err(PyTypeError::new_err(format!(
+ "unsupported indexer {}",
+ py_indexer
+ )))
}
+ }
- for (i, item) in tuple.iter().enumerate() {
- if item.is_ellipsis() {
- // Handle '...' e.g. tensor[..., 0]
+ if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
+ let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
- if i > 0 {
- return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
- }
- indexers.push(Indexer::Elipsis);
- } else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
- // Handle slice
- let index = slice.indices(dims[i] as c_long)?;
- indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
- } else if let Ok(index) = item.extract::<isize>() {
- indexers.push(Indexer::Index(to_absolute_index(index, i)?));
- } else {
- return Err(PyTypeError::new_err("unsupported index"));
- }
+ if not_none_count > dims.len() {
+ return Err(PyValueError::new_err("provided too many indices"));
+ }
+
+ let mut current_dim = 0;
+ for item in tuple.iter() {
+ let (indexer, new_current_dim) =
+ extract_indexer(item, current_dim, dims, not_none_count)?;
+ current_dim = new_current_dim;
+ indexers.push(indexer);
}
} else {
- return Err(PyTypeError::new_err("unsupported index"));
+ let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
+ indexers.push(indexer);
}
let mut x = self.0.clone();
@@ -526,6 +572,22 @@ impl PyTensor {
current_dim += dims.len() - (indexers.len() - 1);
x
}
+ Indexer::Expand => {
+ // Expand is a special case, it means that a new dimension should be added => unsqueeze and advance the current_dim
+ let out = x.unsqueeze(current_dim).map_err(wrap_err)?;
+ current_dim += 1;
+ out
+ }
+ Indexer::IndexSelect(indexes) => {
+ let out = x
+ .index_select(
+ &indexes.to_device(x.device()).map_err(wrap_err)?,
+ current_dim,
+ )
+ .map_err(wrap_err)?;
+ current_dim += 1;
+ out
+ }
}
}
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 659423e0..e4cf19f1 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -55,6 +55,7 @@ def test_tensor_can_be_sliced():
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[-4:-2].values() == [5.0, 9.0]
+ assert t[...].values() == t.values()
def test_tensor_can_be_sliced_2d():
@@ -76,6 +77,43 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+def test_tensor_can_be_expanded_with_none():
+ t = candle.rand((12, 12))
+
+ b = t[None]
+ assert b.shape == (1, 12, 12)
+ c = t[:, None, None, :]
+ assert c.shape == (12, 1, 1, 12)
+ d = t[None, :, None, :]
+ assert d.shape == (1, 12, 1, 12)
+ e = t[None, None, :, :]
+ assert e.shape == (1, 1, 12, 12)
+ f = t[:, :, None]
+ assert f.shape == (12, 12, 1)
+
+
+def test_tensor_can_be_index_via_tensor():
+ t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
+ indexed = t[candle.Tensor([0, 2])]
+ assert indexed.shape == (2, 4)
+ assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
+
+ indexed = t[:, candle.Tensor([0, 2])]
+ assert indexed.shape == (3, 2)
+ assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
+
+
+def test_tensor_can_be_index_via_list():
+ t = candle.Tensor([[1, 2, 1, 2], [3, 4, 3, 4], [5, 6, 5, 6]])
+ indexed = t[[0, 2]]
+ assert indexed.shape == (2, 4)
+ assert indexed.values() == [[1, 2, 1, 2], [5, 6, 5, 6]]
+
+ indexed = t[:, [0, 2]]
+ assert indexed.shape == (3, 2)
+ assert indexed.values() == [[1, 1], [3, 3], [5, 5]]
+
+
def test_tensor_can_be_cast_via_to():
t = Tensor(42.0)
assert str(t.dtype) == str(candle.f32)