-rw-r--r--   candle-core/src/backend.rs            |  9
-rw-r--r--   candle-core/src/backprop.rs           |  5
-rw-r--r--   candle-core/src/cpu_backend.rs        | 12
-rw-r--r--   candle-core/src/cuda_backend.rs       | 11
-rw-r--r--   candle-core/src/dummy_cuda_backend.rs | 11
-rw-r--r--   candle-core/src/op.rs                 |  1
-rw-r--r--   candle-core/src/storage.rs            | 28
-rw-r--r--   candle-core/src/tensor.rs             | 23
8 files changed, 97 insertions, 3 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 1f5f45ab..2b873e6e 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -41,6 +41,15 @@ pub trait BackendStorage: Sized {
     fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;

     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self>;

     fn matmul(
         &self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index bfbb350e..7b493d31 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -38,7 +38,9 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
-                    Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
+                    Op::IndexAdd(t1, t2, t3, _)
+                    | Op::CustomOp3(t1, t2, t3, _)
+                    | Op::WhereCond(t1, t2, t3) => {
                         let (tg, nodes) = walk(t1, nodes, already_seen);
                         track_grad |= tg;
                         let (tg, nodes) = walk(t2, nodes, already_seen);
@@ -160,6 +162,7 @@ impl Tensor {
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                     Op::IndexSelect(arg, indexes, dim) => {
                         let dim = *dim;
                         let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index a471e308..8e9b1d8e 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1532,6 +1532,18 @@ impl BackendStorage for CpuStorage {
         IndexSelect { ids, ids_l, dim }.map(self, l)
     }

+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
     fn matmul(
         &self,
         rhs: &Self,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index cdbfd0c6..a5633836 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1064,6 +1064,17 @@ impl BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(CudaError::InternalError("TODO: implement index-select").into())
     }
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement index-add").into())
+    }

     fn matmul(
         &self,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index b8d4c727..633f146e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -85,6 +85,17 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }

     fn matmul(
         &self,
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index a33dd226..d36aa301 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -67,6 +67,7 @@ pub(crate) enum Op {
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     IndexSelect(Tensor, Tensor, usize),
+    IndexAdd(Tensor, Tensor, Tensor, usize),
     WhereCond(Tensor, Tensor, Tensor),

     #[allow(dead_code)]
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 2df21862..62e2d5e7 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -308,7 +308,7 @@ impl Storage {
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                 let storage = lhs.embedding(layout, rhs, rhs_l)?;
                 Ok(Self::Cpu(storage))
             }
@@ -325,6 +325,30 @@ impl Storage {
         }
     }

+    pub(crate) fn index_add(
+        &self,
+        l: &Layout,
+        indexes: &Self,
+        indexes_l: &Layout,
+        source: &Self,
+        source_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(indexes, "index-add")?;
+        self.same_device(source, "index-add")?;
+        match (self, indexes, source) {
+            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
+                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            _ => unreachable!(),
+        }
+    }
+
     pub(crate) fn index_select(
         &self,
         rhs: &Self,
@@ -334,7 +358,7 @@
     ) -> Result<Self> {
         self.same_device(rhs, "index-select")?;
         match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
                 Ok(Self::Cpu(storage))
             }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e4e4ba6b..d4ee34f9 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -945,6 +945,29 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }

+    pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "index-add")?;
+        let storage = self.storage().index_add(
+            self.layout(),
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::IndexAdd(
+                self.clone(),
+                indexes.clone(),
+                source.clone(),
+                dim,
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-select")?;
         let indexes_len = match indexes.dims() {
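
For reference, a minimal usage sketch of the Tensor::index_add entry point introduced in candle-core/src/tensor.rs above. The shapes and values are illustrative, the import path assumes the crate is used as candle_core, and the accumulation semantics (adding source rows into self at the positions given by indexes along dim, PyTorch-style) are an assumption not stated in the patch; with this commit the CPU kernel is still a todo!() stub and the CUDA kernel returns an InternalError, so the call only demonstrates the intended API shape.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // 4x2 destination tensor, initially all zeros.
    let init = Tensor::zeros((4, 2), DType::F32, &dev)?;

    // Two source rows, intended to be accumulated into rows 0 and 2 of `init`.
    let source = Tensor::new(&[[1f32, 1.], [2., 2.]], &dev)?;
    let indexes = Tensor::new(&[0u32, 2u32], &dev)?;

    // Scatter-add `source` into `init` along dimension 0 at `indexes`
    // (assumed semantics). As of this patch the CPU kernel is todo!(),
    // so this call panics rather than returning a result.
    let out = init.index_add(&indexes, &source, 0)?;
    println!("{out}");
    Ok(())
}

Only the device-dispatch in Storage::index_add and the autograd plumbing (Op::IndexAdd, which currently reports BackwardNotSupported) are wired up by this commit; the per-backend kernels are left as stubs.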