-rw-r--r--  candle-core/src/backend.rs                            1
-rw-r--r--  candle-core/src/backprop.rs                           4
-rw-r--r--  candle-core/src/cpu_backend.rs                       73
-rw-r--r--  candle-core/src/cuda_backend.rs                      46
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs                 3
-rw-r--r--  candle-core/src/op.rs                                 1
-rw-r--r--  candle-core/src/storage.rs                           20
-rw-r--r--  candle-core/src/tensor.rs                            28
-rw-r--r--  candle-core/tests/tensor_tests.rs                     2
-rw-r--r--  candle-examples/examples/musicgen/encodec_model.rs    2
-rw-r--r--  candle-kernels/src/indexing.cu                       40
11 files changed, 11 insertions, 209 deletions
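
Taken together, the hunks below remove the dedicated embedding operation from the Op enum, the backprop graph, the CPU/CUDA/dummy backends, and the EMB_OP CUDA kernels, and reimplement Tensor::embedding as a thin wrapper around index_select on dimension 0. A minimal call-site migration sketch (API names taken from the tensor.rs and tensor_tests.rs hunks; the setup values are illustrative only):

    use candle::{Device, Tensor};

    fn migrate_embedding_call() -> Result<(), candle::Error> {
        // A (vocab_size, hidden_size) table and a 1D tensor of u32 ids.
        let weights = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
        // Old API (removed): Tensor::embedding(&ids, &weights)?
        // New API: a method on the table itself ...
        let emb = weights.embedding(&ids)?;
        // ... which now simply delegates to index_select along dim 0.
        let same = weights.index_select(&ids, 0)?;
        assert_eq!(emb.to_vec2::<f32>()?, same.to_vec2::<f32>()?);
        Ok(())
    }
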
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index cee1cad0..345db0e5 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,7 +37,6 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
- fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index fd1650bb..f5cc8191 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -59,7 +59,6 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
- | Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
@@ -188,9 +187,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
}
- Op::Embedding(_lhs, _rhs) => {
- Err(Error::BackwardNotSupported { op: "embedding" })?
- }
Op::Matmul(lhs, rhs) => {
// Skipping checks, the op went ok, we can skip
// the matmul size checks for now.
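
Dropping the Op::Embedding variant here is safe because, as the tensor.rs hunk further down shows, Tensor::embedding now lowers to index_select, and the IndexSelect arm retained just above already has a backward rule that scatters the gradient back into the table via index_add. A hedged sketch of what this enables, assuming the Var / backward() / GradStore API of candle-core (the exact Var constructor may differ at this revision):

    use candle::{Device, Tensor, Var};

    fn gradient_reaches_the_table() -> Result<(), candle::Error> {
        // Var marks the embedding table as a gradient leaf.
        let weights = Var::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
        // Under Op::Embedding this backward pass returned BackwardNotSupported;
        // routed through index_select it is accumulated with index_add instead.
        let emb = weights.embedding(&ids)?;
        let grads = emb.backward()?;
        assert!(grads.get(&weights).is_some());
        Ok(())
    }
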
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 59c17387..8563721c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -861,58 +861,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
}
}
-struct Embedding<'a, I: IntDType> {
- vocab_size: usize,
- hidden_size: usize,
- ids: &'a [I],
- ids_l: &'a Layout,
-}
-
-impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
- fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
- if !layout.is_contiguous() {
- Err(Error::RequiresContiguous { op: "embedding" })?
- }
- let vs = &vs[layout.start_offset()..];
- let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
- match self.ids_l.contiguous_offsets() {
- Some((o1, o2)) => {
- for index in self.ids[o1..o2].iter() {
- let index = index.as_usize();
- if index >= self.vocab_size {
- Err(Error::InvalidIndex {
- index,
- size: self.vocab_size,
- op: "take",
- }
- .bt())?
- } else {
- let hidden_size = self.hidden_size;
- values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
- }
- }
- }
- None => {
- for index in self.ids_l.strided_index() {
- let index = self.ids[index].as_usize();
- if index >= self.vocab_size {
- Err(Error::InvalidIndex {
- index,
- size: self.vocab_size,
- op: "take",
- }
- .bt())?
- } else {
- let hidden_size = self.hidden_size;
- values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
- }
- }
- }
- }
- Ok(values)
- }
-}
-
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1664,27 +1612,6 @@ impl BackendStorage for CpuStorage {
Conv1D(params).map(self, l, kernel, kernel_l)
}
- fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
- match self {
- Self::U8(ids) => Embedding {
- vocab_size,
- hidden_size,
- ids,
- ids_l,
- }
- .map(rhs, rhs_l),
- Self::U32(ids) => Embedding {
- vocab_size,
- hidden_size,
- ids,
- ids_l,
- }
- .map(rhs, rhs_l),
- _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
- }
- }
-
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
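
The CPU Embedding kernel removed above was a special case of the retained IndexSelect map: for a row-major (v, h) table, both copy row ids[i] of the table into row i of the output. A standalone sketch of that shared inner loop (plain Rust, illustrative only, mirroring the removed contiguous fast path without its bounds checks):

    // Copy row `id` of a row-major (v, h) table for each id, like the
    // removed Embedding::f fast path and index_select on dim 0.
    fn embed_rows(table: &[f32], h: usize, ids: &[usize]) -> Vec<f32> {
        let mut out = Vec::with_capacity(ids.len() * h);
        for &id in ids {
            out.extend_from_slice(&table[id * h..(id + 1) * h]);
        }
        out
    }

    // embed_rows(&[0., 1., 2., 3., 4., 5.], 2, &[2, 1, 2]) == [4., 5., 2., 3., 4., 5.]
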
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 6c98cd0a..7b4b358d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -690,46 +690,6 @@ impl<U: UnaryOpT> Map1 for U {
}
}
-struct Embedding<'a>(&'a CudaStorage, &'a Layout);
-impl<'a> Map1 for Embedding<'a> {
- fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
- &self,
- rhs: &CudaSlice<T>,
- dev: &CudaDevice,
- rhs_l: &Layout,
- ) -> Result<CudaSlice<T>> {
- let ids_l = &self.1;
- let (name, ids) = match &self.0.slice {
- CudaStorageSlice::U32(slice) => {
- ("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
- }
- CudaStorageSlice::U8(slice) => {
- ("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
- }
- _ => Err(CudaError::UnexpectedDType {
- msg: "embedding ids should be u8 or u32",
- expected: DType::U32,
- got: self.0.dtype(),
- })
- .w()?,
- };
- let shape = ids_l.shape();
- let (v_size, h_size) = rhs_l.shape().dims2()?;
- let dims = shape.dims();
- let el = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
- let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }.w()?;
- Ok(out)
- }
-}
-
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for IndexSelect<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1421,12 +1381,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- let device = self.device().clone();
- let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
- Ok(Self { slice, device })
- }
-
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 1213c502..17d4a22e 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -75,9 +75,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
- fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
- Err(Error::NotCompiledWithCudaSupport)
- }
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4f489f30..ba8d2fb4 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -65,7 +65,6 @@ pub enum Op {
// The third argument is the reduced shape with `keepdim=true`.
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
- Embedding(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 545f549b..1e1ef305 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -295,26 +295,6 @@ impl Storage {
}
}
- pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- self.same_device(rhs, "embedding")?;
- match (self, rhs) {
- (Self::Cpu(lhs), Self::Cpu(rhs)) => {
- let storage = lhs.embedding(layout, rhs, rhs_l)?;
- Ok(Self::Cpu(storage))
- }
- (Self::Cuda(lhs), Self::Cuda(rhs)) => {
- let storage = lhs.embedding(layout, rhs, rhs_l)?;
- Ok(Self::Cuda(storage))
- }
- (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
- lhs: lhs.device().location(),
- rhs: rhs.device().location(),
- op: "embedding",
- }
- .bt()),
- }
- }
-
pub(crate) fn gather(
&self,
l: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 060e8792..c326a5ac 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -842,45 +842,35 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
- /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
+ /// Returns a tensor with the values from the `self` tensor at the index corresponding to the
/// values held in the `ids` tensor.
///
/// # Arguments
///
+ /// * `self` - A tensor with dimensions `v, h`.
/// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
- /// * `rhs` - A tensor with dimensions `v, h`.
///
/// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
/// vocabulary size, and `h` the hidden size.
///
/// ```rust
/// use candle::{Tensor, Device};
- /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+ /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
- /// let emb = Tensor::embedding(&ids, &rhs)?;
+ /// let emb = values.embedding(&ids)?;
/// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
/// # Ok::<(), candle::Error>(())
/// ```
- pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
- if !rhs.is_contiguous() {
- Err(Error::RequiresContiguous { op: "embedding" }.bt())?
- } else if rhs.rank() != 2 || ids.rank() != 1 {
+ pub fn embedding(&self, ids: &Self) -> Result<Self> {
+ if self.rank() != 2 || ids.rank() != 1 {
Err(Error::ShapeMismatchBinaryOp {
- lhs: ids.shape().clone(),
- rhs: rhs.shape().clone(),
+ lhs: self.shape().clone(),
+ rhs: ids.shape().clone(),
op: "embedding",
}
.bt())?
}
- let ids_shape = ids.shape();
- let seq_len = ids_shape.dims1()?;
- let (_, hidden_size) = rhs.dims2()?;
- let storage = ids
- .storage()
- .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
- let shape: Shape = (seq_len, hidden_size).into();
- let op = BackpropOp::new2(ids, rhs, Op::Embedding);
- Ok(from_storage(storage, shape, op, false))
+ self.index_select(ids, 0)
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a8702df7..2147759d 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -534,7 +534,7 @@ fn cat(device: &Device) -> Result<()> {
fn embeddings(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
- let hs = Tensor::embedding(&ids, &t)?;
+ let hs = t.embedding(&ids)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 2ef6f20f..eaf4ca05 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -142,7 +142,7 @@ impl EncodecEuclideanCodebook {
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
- let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+ let quantize = self.embed.embedding(embed_ind)?;
Ok(quantize)
}
}
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 359db498..7723d3bc 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -3,32 +3,6 @@
#include "cuda_utils.cuh"
#include<stdint.h>
-#define EMB_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
- const size_t numel, \
- const size_t num_dims, \
- const size_t *info, \
- const INDEX_TYPENAME *ids, \
- const TYPENAME *inp, \
- TYPENAME *out, \
- const size_t h_size, \
- const size_t v_size \
-) { \
- const size_t *dims = info; \
- const size_t *strides = info + num_dims; \
- if (is_contiguous(num_dims, dims, strides)) { \
- for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
- memcpy(&out[i * h_size], &inp[ids[i] * h_size], h_size * sizeof(TYPENAME)); \
- } \
- } \
- else { \
- for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
- unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
- memcpy(&out[i * h_size], &inp[ids[strided_i] * h_size], h_size * sizeof(TYPENAME)); \
- } \
- } \
-} \
-
template<typename T, typename I>
__device__ void index_select(
const size_t numel,
@@ -177,8 +151,6 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
-EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
-EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
@@ -190,8 +162,6 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
-EMB_OP(__half, uint32_t, emb_u32_f16)
-EMB_OP(__half, uint8_t, emb_u8_f16)
IS_OP(__half, uint32_t, is_u32_f16)
IS_OP(__half, uint8_t, is_u8_f16)
GATHER_OP(__half, uint32_t, gather_u32_f16)
@@ -202,16 +172,6 @@ SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
#endif
-EMB_OP(float, uint32_t, emb_u32_f32)
-EMB_OP(double, uint32_t, emb_u32_f64)
-EMB_OP(uint8_t, uint32_t, emb_u32_u8)
-EMB_OP(uint32_t, uint32_t, emb_u32_u32)
-
-EMB_OP(float, uint8_t, emb_u8_f32)
-EMB_OP(double, uint8_t, emb_u8_f64)
-EMB_OP(uint8_t, uint8_t, emb_u8_u8)
-EMB_OP(uint32_t, uint8_t, emb_u8_u32)
-
IS_OP(float, uint32_t, is_u32_f32)
IS_OP(double, uint32_t, is_u32_f64)
IS_OP(uint8_t, uint32_t, is_u32_u8)