summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs113
1 files changed, 40 insertions, 73 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index dc0d51bf..7d06dd72 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -373,6 +373,44 @@ impl<U: crate::op::UnaryOp> Map1 for U {
}
}
+struct Embedding<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map1 for Embedding<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ rhs: &CudaSlice<T>,
+ dev: &CudaDevice,
+ rhs_l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let ids_l = &self.1;
+ let ids = match &self.0.slice {
+ CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ _ => Err(CudaError::UnexpectedDType {
+ msg: "embedding ids should be u32",
+ expected: DType::U32,
+ got: self.0.dtype(),
+ })?,
+ };
+ let ids = &ids;
+ let shape = ids_l.shape();
+ let (v_size, h_size) = rhs_l
+ .shape()
+ .r2()
+ .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el * h_size) }?;
+ let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
src_l: &Layout,
@@ -760,79 +798,8 @@ impl CudaStorage {
}
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
- _ => Err(CudaError::UnexpectedDType {
- msg: "embedding ids should be u32",
- expected: DType::U32,
- got: self.dtype(),
- })?,
- };
- let ids = &ids;
- let shape = layout.shape();
- let (v_size, h_size) = rhs_l
- .shape()
- .r2()
- .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
- let dims = shape.dims();
- let el = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let dev = self.device();
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
- let slice = match &rhs.slice {
- // The kernels below assume that rhs is contiguous.
- CudaStorageSlice::U32(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- CudaStorageSlice::BF16(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- CudaStorageSlice::F16(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- CudaStorageSlice::F32(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- CudaStorageSlice::F64(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
Ok(Self { slice, device })
}