diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-22 08:21:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 07:21:28 +0100 |
commit | 52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch) | |
tree | 8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/cuda_backend.rs | |
parent | 6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff) | |
download | candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2 candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip |
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index a5633836..43bfef2d 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1064,6 +1064,20 @@ impl BackendStorage for CudaStorage { fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> { Err(CudaError::InternalError("TODO: implement index-select").into()) } + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> { + Err(CudaError::InternalError("TODO: implement gather").into()) + } + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result<Self> { + Err(CudaError::InternalError("TODO: implement scatter-add").into()) + } fn index_add( &self, _: &Layout, |