author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-02 05:42:11 +0100
committer GitHub <noreply@github.com>                2023-08-02 05:42:11 +0100
commit    4b3bd79fbd7ea562fef8f747f2dec224afad26da (patch)
tree      57db1ca57be638518355918e224c69ae3b455a55 /candle-core/src
parent    cc76c63202ab936c08f6a6b9dcc2756c6a227f63 (diff)
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.
* Also remove the cuda kernels.
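The observation behind the change: for a 2D table, the old embedding op computes exactly an
index-select along dimension 0, so the dedicated op (and its dedicated CPU/CUDA kernels) can be
droppeded in favor of the existing, more general one. A minimal sketch of the equivalence, using
the post-commit method-style API shown in the tensor.rs doctest below:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // A (vocab_size, hidden_size) table and a sequence of token ids.
        let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1, 2], &Device::Cpu)?;
        // `embedding` is now a thin wrapper: both calls gather rows 2, 1, 2
        // and return the same (seq_len, hidden_size) tensor.
        let emb = values.embedding(&ids)?;
        let sel = values.index_select(&ids, 0)?;
        assert_eq!(emb.to_vec2::<f32>()?, sel.to_vec2::<f32>()?);
        Ok(())
    }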
Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backend.rs               1
-rw-r--r--  candle-core/src/backprop.rs              4
-rw-r--r--  candle-core/src/cpu_backend.rs          73
-rw-r--r--  candle-core/src/cuda_backend.rs         46
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs    3
-rw-r--r--  candle-core/src/op.rs                    1
-rw-r--r--  candle-core/src/storage.rs              20
-rw-r--r--  candle-core/src/tensor.rs               28
8 files changed, 9 insertions(+), 167 deletions(-)
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index cee1cad0..345db0e5 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,7 +37,6 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;

-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
     fn scatter_add(
         &self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fd1650bb..f5cc8191 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -59,7 +59,6 @@ impl Tensor {
                     | Op::Binary(lhs, rhs, _)
                     | Op::Gather(lhs, rhs, _)
                     | Op::IndexSelect(lhs, rhs, _)
-                    | Op::Embedding(lhs, rhs)
                     | Op::Matmul(lhs, rhs) => {
                         let (tg, nodes) = walk(lhs, nodes, already_seen);
                         track_grad |= tg;
@@ -188,9 +187,6 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                     }
-                    Op::Embedding(_lhs, _rhs) => {
-                        Err(Error::BackwardNotSupported { op: "embedding" })?
-                    }
                     Op::Matmul(lhs, rhs) => {
                         // Skipping checks, the op went ok, we can skip
                         // the matmul size checks for now.
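A side benefit visible in this hunk: the removed Op::Embedding arm could only bail out with
BackwardNotSupported, while the Op::IndexSelect path that replaces it has a real gradient,
scattering the incoming gradient back into the selected rows via index_add (the context lines
above). Stripped of tensors and layouts, that scatter-accumulate is, as a hypothetical sketch
over flat row-major data:

    // grad_out has one row of width `h` per selected index; the backward pass
    // accumulates each row into the source row it was gathered from. Rows
    // selected several times (a repeated token id, say) receive the sum of
    // their output gradients.
    fn index_select_backward(grad_out: &[f32], ids: &[usize], src_rows: usize, h: usize) -> Vec<f32> {
        let mut grad_src = vec![0f32; src_rows * h];
        for (i, &id) in ids.iter().enumerate() {
            for j in 0..h {
                grad_src[id * h + j] += grad_out[i * h + j];
            }
        }
        grad_src
    }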
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 59c17387..8563721c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -861,58 +861,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     }
 }

-struct Embedding<'a, I: IntDType> {
-    vocab_size: usize,
-    hidden_size: usize,
-    ids: &'a [I],
-    ids_l: &'a Layout,
-}
-
-impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
-    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
-        if !layout.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" })?
-        }
-        let vs = &vs[layout.start_offset()..];
-        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
-        match self.ids_l.contiguous_offsets() {
-            Some((o1, o2)) => {
-                for index in self.ids[o1..o2].iter() {
-                    let index = index.as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-            None => {
-                for index in self.ids_l.strided_index() {
-                    let index = self.ids[index].as_usize();
-                    if index >= self.vocab_size {
-                        Err(Error::InvalidIndex {
-                            index,
-                            size: self.vocab_size,
-                            op: "take",
-                        }
-                        .bt())?
-                    } else {
-                        let hidden_size = self.hidden_size;
-                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-                    }
-                }
-            }
-        }
-        Ok(values)
-    }
-}
-
 fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1664,27 +1612,6 @@ impl BackendStorage for CpuStorage {
         Conv1D(params).map(self, l, kernel, kernel_l)
     }

-    fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
-        match self {
-            Self::U8(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            Self::U32(ids) => Embedding {
-                vocab_size,
-                hidden_size,
-                ids,
-                ids_l,
-            }
-            .map(rhs, rhs_l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
-        }
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         match ids {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
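The removed Embedding map is not lost functionality: modulo layout bookkeeping and the
out-of-bounds check, both it and IndexSelect (with dim = 0 and contiguous inputs) reduce to the
same row gather. A simplified, hypothetical version of that core loop:

    // Copy row `id` (width `h`) of a flat row-major table for each id, in order.
    fn gather_rows<T: Copy>(vs: &[T], ids: &[usize], h: usize) -> Vec<T> {
        let mut out = Vec::with_capacity(ids.len() * h);
        for &id in ids {
            out.extend_from_slice(&vs[id * h..(id + 1) * h]);
        }
        out
    }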
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 6c98cd0a..7b4b358d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -690,46 +690,6 @@ impl<U: UnaryOpT> Map1 for U {
     }
 }

-struct Embedding<'a>(&'a CudaStorage, &'a Layout);
-impl<'a> Map1 for Embedding<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        rhs: &CudaSlice<T>,
-        dev: &CudaDevice,
-        rhs_l: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let ids_l = &self.1;
-        let (name, ids) = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => {
-                ("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            CudaStorageSlice::U8(slice) => {
-                ("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            _ => Err(CudaError::UnexpectedDType {
-                msg: "embedding ids should be u8 or u32",
-                expected: DType::U32,
-                got: self.0.dtype(),
-            })
-            .w()?,
-        };
-        let shape = ids_l.shape();
-        let (v_size, h_size) = rhs_l.shape().dims2()?;
-        let dims = shape.dims();
-        let el = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
-        let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
-        // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
-        let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
-    }
-}
-
 struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
 impl<'a> Map1 for IndexSelect<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1421,12 +1381,6 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }

-    fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let device = self.device().clone();
-        let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
-        Ok(Self { slice, device })
-    }
-
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 1213c502..17d4a22e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -75,9 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4f489f30..ba8d2fb4 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -65,7 +65,6 @@ pub enum Op {
     // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
-    Embedding(Tensor, Tensor),
     Gather(Tensor, Tensor, usize),
     ScatterAdd(Tensor, Tensor, Tensor, usize),
     IndexSelect(Tensor, Tensor, usize),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 545f549b..1e1ef305 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -295,26 +295,6 @@ impl Storage {
         }
     }

-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        self.same_device(rhs, "embedding")?;
-        match (self, rhs) {
-            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cpu(storage))
-            }
-            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding(layout, rhs, rhs_l)?;
-                Ok(Self::Cuda(storage))
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "embedding",
-            }
-            .bt()),
-        }
-    }
-
     pub(crate) fn gather(
         &self,
         l: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 060e8792..c326a5ac 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -842,45 +842,35 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }

-    /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
+    /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
     /// values hold in the `ids` tensor.
     ///
     /// # Arguments
     ///
+    /// * `self` - A tensor with dimensions `v, h`.
     /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
-    /// * `rhs` - A tensor with dimensions `v, h`.
     ///
     /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
     /// vocabulary size, and `h` the hidden size.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
-    /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
     /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
-    /// let emb = Tensor::embedding(&ids, &rhs)?;
+    /// let emb = values.embedding(&ids)?;
     /// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
-        if !rhs.is_contiguous() {
-            Err(Error::RequiresContiguous { op: "embedding" }.bt())?
-        } else if rhs.rank() != 2 || ids.rank() != 1 {
+    pub fn embedding(&self, ids: &Self) -> Result<Self> {
+        if self.rank() != 2 || ids.rank() != 1 {
             Err(Error::ShapeMismatchBinaryOp {
-                lhs: ids.shape().clone(),
-                rhs: rhs.shape().clone(),
+                lhs: self.shape().clone(),
+                rhs: ids.shape().clone(),
                 op: "embedding",
             }
             .bt())?
         }
-        let ids_shape = ids.shape();
-        let seq_len = ids_shape.dims1()?;
-        let (_, hidden_size) = rhs.dims2()?;
-        let storage = ids
-            .storage()
-            .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
-        let shape: Shape = (seq_len, hidden_size).into();
-        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
-        Ok(from_storage(storage, shape, op, false))
+        self.index_select(ids, 0)
     }

     pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
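For callers, the migration is mechanical (Tensor::embedding(&ids, &rhs) becomes
rhs.embedding(&ids)), and because the new body records Op::IndexSelect rather than the
now-deleted Op::Embedding, embedding lookups become differentiable for free. A sketch,
assuming candle's Var / backward / GradStore API:

    use candle::{Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let table = Var::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1, 2], &Device::Cpu)?;
        // Before this commit, this backward pass failed with BackwardNotSupported.
        let loss = table.embedding(&ids)?.sum_all()?;
        let grads = loss.backward()?;
        // The table now receives a gradient, accumulated row-wise via index_add.
        assert!(grads.get(&table).is_some());
        Ok(())
    }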