diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-14 08:47:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 07:47:07 +0100 |
commit | d6447ad635bc450ef1f15ca7a4424c0f86e7a90a (patch) | |
tree | 2629de94c8b6bf383d67e577357a8cb2e85155b7 | |
parent | 49d3f7f70814bd0e8b569f93bb76419306359251 (diff) | |
download | candle-d6447ad635bc450ef1f15ca7a4424c0f86e7a90a.tar.gz candle-d6447ad635bc450ef1f15ca7a4424c0f86e7a90a.tar.bz2 candle-d6447ad635bc450ef1f15ca7a4424c0f86e7a90a.zip |
Tensor based indexing. (#842)
-rw-r--r-- | candle-core/src/indexer.rs | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 2b6d694b..7b84d316 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -46,19 +46,31 @@ impl Tensor { current_dim += 1; out } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + crate::bail!("multi-dimensional tensor indexing is not supported") + } + let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?; + current_dim += 1; + out + } + TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"), }; } Ok(x) } } -#[derive(Debug, Clone)] +#[derive(Debug)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { /// This selects the elemnts for which an index has some specific value. Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound<usize>, Bound<usize>), + /// Indexing via a 1d tensor + IndexSelect(Tensor), + Err(Error), } impl From<usize> for TensorIndexer { @@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer { } } +impl From<&[u32]> for TensorIndexer { + fn from(index: &[u32]) -> Self { + match Tensor::new(index, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<Vec<u32>> for TensorIndexer { + fn from(index: Vec<u32>) -> Self { + let len = index.len(); + match Tensor::from_vec(index, len, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<&Tensor> for TensorIndexer { + fn from(tensor: &Tensor) -> Self { + TensorIndexer::IndexSelect(tensor.clone()) + } +} + macro_rules! impl_from_range { ($range_type:ty) => { impl From<$range_type> for TensorIndexer { |