summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/backend.rs 9
-rw-r--r-- candle-core/src/backprop.rs 5
-rw-r--r-- candle-core/src/cpu_backend.rs 12
-rw-r--r-- candle-core/src/cuda_backend.rs 11
-rw-r--r-- candle-core/src/dummy_cuda_backend.rs 11
-rw-r--r-- candle-core/src/op.rs 1
-rw-r--r-- candle-core/src/storage.rs 28
-rw-r--r-- candle-core/src/tensor.rs 23
8 files changed, 97 insertions, 3 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 1f5f45ab..2b873e6e 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -41,6 +41,15 @@ pub trait BackendStorage: Sized {
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
+ /// Backend hook for the index-add operation; the result is returned as new storage.
+ ///
+ /// Arguments, in order: the layout of `self`, the index storage and its layout,
+ /// the source storage and its layout, and the dimension the indexes refer to.
+ /// NOTE(review): the exact accumulation semantics are defined by each backend
+ /// implementation and are not fixed by this declaration.
+ fn index_add(
+     &self,
+     l: &Layout,
+     indexes: &Self,
+     indexes_l: &Layout,
+     source: &Self,
+     source_l: &Layout,
+     dim: usize,
+ ) -> Result<Self>;
fn matmul(
&self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index bfbb350e..7b493d31 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -38,7 +38,9 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
- Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
+ Op::IndexAdd(t1, t2, t3, _)
+ | Op::CustomOp3(t1, t2, t3, _)
+ | Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
@@ -160,6 +162,7 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let dim = *dim;
let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index a471e308..8e9b1d8e 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1532,6 +1532,18 @@ impl BackendStorage for CpuStorage {
IndexSelect { ids, ids_l, dim }.map(self, l)
}
+ /// CPU index-add. Not implemented yet: every call panics through `todo!()`.
+ fn index_add(
+     &self,
+     _l: &Layout,
+     _indexes: &Self,
+     _indexes_l: &Layout,
+     _source: &Self,
+     _source_l: &Layout,
+     _dim: usize,
+ ) -> Result<Self> {
+     // Placeholder body; none of the arguments are inspected.
+     todo!()
+ }
+
fn matmul(
&self,
rhs: &Self,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index cdbfd0c6..a5633836 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1064,6 +1064,17 @@ impl BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-select").into())
}
+ /// CUDA index-add. No kernel is wired up yet, so this always returns an
+ /// internal error instead of computing anything.
+ fn index_add(
+     &self,
+     _l: &Layout,
+     _indexes: &Self,
+     _indexes_l: &Layout,
+     _source: &Self,
+     _source_l: &Layout,
+     _dim: usize,
+ ) -> Result<Self> {
+     let not_implemented = CudaError::InternalError("TODO: implement index-add");
+     Err(not_implemented.into())
+ }
fn matmul(
&self,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index b8d4c727..633f146e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -85,6 +85,17 @@ impl crate::backend::BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
+ /// Stub used when the crate is built without CUDA: index-add always fails with
+ /// `Error::NotCompiledWithCudaSupport`.
+ fn index_add(
+     &self,
+     _l: &Layout,
+     _indexes: &Self,
+     _indexes_l: &Layout,
+     _source: &Self,
+     _source_l: &Layout,
+     _dim: usize,
+ ) -> Result<Self> {
+     let no_cuda = Error::NotCompiledWithCudaSupport;
+     Err(no_cuda)
+ }
fn matmul(
&self,
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index a33dd226..d36aa301 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -67,6 +67,7 @@ pub(crate) enum Op {
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
IndexSelect(Tensor, Tensor, usize),
+ IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
#[allow(dead_code)]
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 2df21862..62e2d5e7 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -308,7 +308,7 @@ impl Storage {
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
- (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+ (Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cpu(storage))
}
@@ -325,6 +325,30 @@ impl Storage {
}
}
+ /// Forwards index-add to the backend storage wrapped by `self` and re-wraps the
+ /// backend result in the matching `Storage` variant.
+ ///
+ /// `indexes` and `source` are each checked against `self` with `same_device`
+ /// (op name "index-add"); any error from those checks or from the backend call
+ /// is propagated to the caller.
+ pub(crate) fn index_add(
+     &self,
+     l: &Layout,
+     indexes: &Self,
+     indexes_l: &Layout,
+     source: &Self,
+     source_l: &Layout,
+     d: usize,
+ ) -> Result<Self> {
+     self.same_device(indexes, "index-add")?;
+     self.same_device(source, "index-add")?;
+     match (self, indexes, source) {
+         (Self::Cpu(dst), Self::Cpu(ids), Self::Cpu(src)) => dst
+             .index_add(l, ids, indexes_l, src, source_l, d)
+             .map(Self::Cpu),
+         (Self::Cuda(dst), Self::Cuda(ids), Self::Cuda(src)) => dst
+             .index_add(l, ids, indexes_l, src, source_l, d)
+             .map(Self::Cuda),
+         // Mixed-device combinations are expected to have been rejected by the
+         // `same_device` checks above.
+         _ => unreachable!(),
+     }
+ }
+
pub(crate) fn index_select(
&self,
rhs: &Self,
@@ -334,7 +358,7 @@ impl Storage {
) -> Result<Self> {
self.same_device(rhs, "index-select")?;
match (self, rhs) {
- (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+ (Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cpu(storage))
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e4e4ba6b..d4ee34f9 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -945,6 +945,29 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ /// Runs the storage-level index-add of `source` into `self`, using `indexes`
+ /// along dimension `dim`, and returns the result as a new tensor.
+ ///
+ /// The returned tensor has the same shape as `self`. `dim` is resolved against
+ /// the shape of `self`, so an invalid dimension is reported as an error.
+ ///
+ /// The op is recorded in the graph as `Op::IndexAdd` when any of the three
+ /// operands tracks ops. The backward pass for this op currently returns
+ /// `Error::BackwardNotSupported`.
+ ///
+ /// NOTE(review): no shape compatibility check between `self`, `indexes` and
+ /// `source` happens here; that is left to the backend implementation.
+ pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+     let dim = dim.to_index(self.shape(), "index-add")?;
+     let storage = self.storage().index_add(
+         self.layout(),
+         &indexes.storage(),
+         indexes.layout(),
+         &source.storage(),
+         source.layout(),
+         dim,
+     )?;
+     // `source` feeds values into the result, so it must be able to keep the op
+     // in the graph on its own; previously only `self` and `indexes` were checked.
+     let tracked = self.track_op() || indexes.track_op() || source.track_op();
+     let op = if tracked {
+         Some(Op::IndexAdd(
+             self.clone(),
+             indexes.clone(),
+             source.clone(),
+             dim,
+         ))
+     } else {
+         None
+     };
+     Ok(from_storage(storage, self.shape(), op, false))
+ }
+
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {