author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-02 05:42:11 +0100
committer  GitHub <noreply@github.com>                2023-08-02 05:42:11 +0100
commit     4b3bd79fbd7ea562fef8f747f2dec224afad26da (patch)
tree       57db1ca57be638518355918e224c69ae3b455a55 /candle-core/src
parent     cc76c63202ab936c08f6a6b9dcc2756c6a227f63 (diff)
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.
* Also remove the cuda kernels.
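
For call sites, the change is mechanical: the old associated function `Tensor::embedding(&ids, &rhs)` becomes a method on the values tensor, backed by `index_select` along dim 0. A minimal sketch adapted from the doctest updated in this diff (the `main` wrapper is illustrative, not part of the patch):

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        // A (v, h) = (3, 2) embedding table and a sequence of 3 ids.
        let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
        // Old call site: let emb = Tensor::embedding(&ids, &values)?;
        let emb = values.embedding(&ids)?;
        // embedding is now a thin wrapper over index_select on the vocab dim.
        let same = values.index_select(&ids, 0)?;
        assert_eq!(emb.to_vec2::<f32>()?, same.to_vec2::<f32>()?);
        assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
        Ok(())
    }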
Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backend.rs             |  1
-rw-r--r--  candle-core/src/backprop.rs            |  4
-rw-r--r--  candle-core/src/cpu_backend.rs         | 73
-rw-r--r--  candle-core/src/cuda_backend.rs        | 46
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  |  3
-rw-r--r--  candle-core/src/op.rs                  |  1
-rw-r--r--  candle-core/src/storage.rs             | 20
-rw-r--r--  candle-core/src/tensor.rs              | 28
8 files changed, 9 insertions, 167 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index cee1cad0..345db0e5 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,7 +37,6 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;
 
-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
     fn scatter_add(
         &self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fd1650bb..f5cc8191 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -59,7 +59,6 @@ impl Tensor {
                 | Op::Binary(lhs, rhs, _)
                 | Op::Gather(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
-                | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
                     track_grad |= tg;
@@ -188,9 +187,6 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                 }
-                Op::Embedding(_lhs, _rhs) => {
-                    Err(Error::BackwardNotSupported { op: "embedding" })?
-                }
                 Op::Matmul(lhs, rhs) => {
                     // Skipping checks, the op went ok, we can skip
                     // the matmul size checks for now.
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 59c17387..8563721c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -861,58 +861,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     }
 }
 
-struct Embedding<'a, I: IntDType> {
-    vocab_size: usize,
-    hidden_size: usize,
-    ids: &'a [I],
-    ids_l: &'a Layout,
-}
-
-impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
-    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
-        if !layout.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" })?
-        }
-        let vs = &vs[layout.start_offset()..];
-        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
-        match self.ids_l.contiguous_offsets() {
-            Some((o1, o2)) => {
-                for index in self.ids[o1..o2].iter() {
-                    let index = index.as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-            None => {
-                for index in self.ids_l.strided_index() {
-                    let index = self.ids[index].as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-        }
-        Ok(values)
-    }
-}
-
 fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1664,27 +1612,6 @@ impl BackendStorage for CpuStorage {
         Conv1D(params).map(self, l, kernel, kernel_l)
     }
 
-    fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
-        match self {
-            Self::U8(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            Self::U32(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
-        }
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         match ids {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 6c98cd0a..7b4b358d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -690,46 +690,6 @@ impl<U: UnaryOpT> Map1 for U {
     }
 }
 
-struct Embedding<'a>(&'a CudaStorage, &'a Layout);
-impl<'a> Map1 for Embedding<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        rhs: &CudaSlice<T>,
-        dev: &CudaDevice,
-        rhs_l: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let ids_l = &self.1;
-        let (name, ids) = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => {
-                ("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            CudaStorageSlice::U8(slice) => {
-                ("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            _ => Err(CudaError::UnexpectedDType {
-                msg: "embedding ids should be u8 or u32",
-                expected: DType::U32,
-                got: self.0.dtype(),
-            })
-            .w()?,
-        };
-        let shape = ids_l.shape();
-        let (v_size, h_size) = rhs_l.shape().dims2()?;
-        let dims = shape.dims();
-        let el = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
-        let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
-        // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
-        let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
-    }
-}
-
 struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
 impl<'a> Map1 for IndexSelect<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1421,12 +1381,6 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let device = self.device().clone();
-        let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
-        Ok(Self { slice, device })
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 1213c502..17d4a22e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -75,9 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4f489f30..ba8d2fb4 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -65,7 +65,6 @@ pub enum Op {
     // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
-    Embedding(Tensor, Tensor),
     Gather(Tensor, Tensor, usize),
     ScatterAdd(Tensor, Tensor, Tensor, usize),
     IndexSelect(Tensor, Tensor, usize),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 545f549b..1e1ef305 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -295,26 +295,6 @@ impl Storage {
         }
     }
 
-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        self.same_device(rhs, "embedding")?;
-        match (self, rhs) {
-            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cpu(storage))
-            }
-            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cuda(storage))
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "embedding",
-            }
-            .bt()),
-        }
-    }
-
     pub(crate) fn gather(
         &self,
         l: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 060e8792..c326a5ac 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -842,45 +842,35 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
-    /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
+    /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
     /// values hold in the `ids` tensor.
     ///
     /// # Arguments
     ///
+    /// * `self` - A tensor with dimensions `v, h`.
     /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
-    /// * `rhs` - A tensor with dimensions `v, h`.
     ///
     /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
     /// vocabulary size, and `h` the hidden size.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
-    /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
-    /// let emb = Tensor::embedding(&ids, &rhs)?;
+    /// let emb = values.embedding(&ids)?;
     /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
-        if !rhs.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
-        } else if rhs.rank() != 2 || ids.rank() != 1 {
+    pub fn embedding(&self, ids: &Self) -> Result<Self> {
+        if self.rank() != 2 || ids.rank() != 1 {
             Err(Error::ShapeMismatchBinaryOp {
-                lhs: ids.shape().clone(),
-                rhs: rhs.shape().clone(),
+                lhs: self.shape().clone(),
+                rhs: ids.shape().clone(),
                 op: "embedding",
             }
             .bt())?
         }
-        let ids_shape = ids.shape();
-        let seq_len = ids_shape.dims1()?;
-        let (_, hidden_size) = rhs.dims2()?;
-        let storage = ids
-            .storage()
-            .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
-        let shape: Shape = (seq_len, hidden_size).into();
-        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
-        Ok(from_storage(storage, shape, op, false))
+        self.index_select(ids, 0)
     }
 
     pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
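
A consequence visible in the backprop.rs hunk above: the removed `Op::Embedding` arm answered backward passes with `Error::BackwardNotSupported`, while `Op::IndexSelect` backpropagates by scattering the incoming gradient into the source rows with `index_add`. A sketch of the gradient flow this routing enables; it assumes the `Var` API and the ones-seeded `Tensor::backward()` present in this tree, and is not part of the patch:

    use candle::{Device, Tensor, Var};

    fn main() -> Result<(), candle::Error> {
        // A trainable (v, h) = (3, 2) table; Var marks it as a gradient leaf.
        let table = Var::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
        let emb = table.embedding(&ids)?;
        // backward() seeds the output gradient with ones; the IndexSelect
        // backward pass accumulates it row-wise into the table via index_add.
        let grads = emb.backward()?;
        let g = grads.get(&table).expect("gradient for the embedding table");
        // Row 2 was selected twice, row 1 once, row 0 never.
        assert_eq!(g.to_vec2::<f32>()?, &[[0., 0.], [1., 1.], [2., 2.]]);
        Ok(())
    }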