summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/indexer.rs39
1 files changed, 38 insertions, 1 deletions
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 2b6d694b..7b84d316 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
+ TensorIndexer::IndexSelect(indexes) => {
+ if indexes.rank() != 1 {
+ crate::bail!("multi-dimensional tensor indexing is not supported")
+ }
+ let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
+ current_dim += 1;
+ out
+ }
+ TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
+ /// Indexing via a 1d tensor
+ IndexSelect(Tensor),
+ Err(Error),
}
impl From<usize> for TensorIndexer {
@@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
+impl From<&[u32]> for TensorIndexer {
+ fn from(index: &[u32]) -> Self {
+ match Tensor::new(index, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<Vec<u32>> for TensorIndexer {
+ fn from(index: Vec<u32>) -> Self {
+ let len = index.len();
+ match Tensor::from_vec(index, len, &crate::Device::Cpu) {
+ Ok(tensor) => TensorIndexer::IndexSelect(tensor),
+ Err(e) => TensorIndexer::Err(e),
+ }
+ }
+}
+
+impl From<&Tensor> for TensorIndexer {
+ fn from(tensor: &Tensor) -> Self {
+ TensorIndexer::IndexSelect(tensor.clone())
+ }
+}
+
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {