summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-22 08:21:28 +0200
committerGitHub <noreply@github.com>2023-07-22 07:21:28 +0100
commit52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch)
tree8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/op.rs
parent6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff)
downloadcandle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip
Add the gather op. (#219)
* Start adding gather. * Gather cpu implementation + use in simple training. * Add scatter_add for the gradient of gather. * Simple cpu implementation of scatter_add. * Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs2
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index d36aa301..de5094bd 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -66,6 +66,8 @@ pub(crate) enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
+ Gather(Tensor, Tensor, usize),
+ ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),