diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-22 08:21:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 07:21:28 +0100 |
commit | 52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch) | |
tree | 8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/op.rs | |
parent | 6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff) | |
download | candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2 candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip |
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index d36aa301..de5094bd 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -66,6 +66,8 @@ pub(crate) enum Op { Reduce(Tensor, ReduceOp, Vec<usize>), Matmul(Tensor, Tensor), Embedding(Tensor, Tensor), + Gather(Tensor, Tensor, usize), + ScatterAdd(Tensor, Tensor, Tensor, usize), IndexSelect(Tensor, Tensor, usize), IndexAdd(Tensor, Tensor, Tensor, usize), WhereCond(Tensor, Tensor, Tensor), |