diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
commit | 3f0d9fbb257baf94acde184de76eb9667e0fa025 (patch) | |
tree | 9bd3217971362a991faac24968f9bf77bf663476 | |
parent | cca699be6c8167f565067ceb3c940dd3c1d87503 (diff) | |
download | candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.gz candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.bz2 candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.zip |
Adapt the cuda bits.
-rw-r--r-- | candle-core/src/cpu_backend.rs | 19 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 157 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 14 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 4 |
5 files changed, 109 insertions, 87 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 9f0c8602..f1547b3c 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -101,14 +101,9 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>( } } -fn take_impl1<T: Copy>( - vs: &[T], - ids: &[u32], - layout: &Layout, - vocab_size: usize, - hidden_size: usize, -) -> Result<Vec<T>> { +fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> { // TODO: Optimize for the case where ids are contiguous. + let (vocab_size, hidden_size) = rhs_l.shape().r2()?; let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size); for index in layout.strided_index() { let index = ids[index].try_into()?; @@ -610,15 +605,9 @@ impl CpuStorage { } } - pub(crate) fn embedding( - &self, - layout: &Layout, - vs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result<Self> { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { let ids = self.as_slice::<u32>()?; - map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) + map1!(rhs, take_impl1, ids, layout, rhs_l) } pub(crate) fn matmul( diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9cbf82be..f50d7cbb 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,4 +1,4 @@ -use crate::{CpuStorage, DType, Shape}; +use crate::{CpuStorage, DType, Layout, Shape}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; @@ -26,6 +26,9 @@ pub enum CudaError { #[error("internal error '{0}'")] InternalError(&'static str), + #[error("internal error '{0}'")] + WrappedError(Box<dyn std::error::Error>), + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { lhs_stride: Vec<usize>, @@ -268,12 +271,14 @@ fn gemm_config<T>( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result<StridedBatchedConfig<T>> { // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; @@ -352,19 +357,21 @@ impl CudaStorage { &self.device } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { use cudarc::driver::DevicePtr; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let start_o = layout.start_offset(); let inp = match &self.slice { - CudaStorageSlice::U32(inp) => inp.device_ptr(), - CudaStorageSlice::BF16(inp) => inp.device_ptr(), - CudaStorageSlice::F16(inp) => inp.device_ptr(), - CudaStorageSlice::F32(inp) => inp.device_ptr(), - CudaStorageSlice::F64(inp) => inp.device_ptr(), + CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(), }; let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; @@ -406,20 +413,16 @@ impl CudaStorage { }) } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result<Self> { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<u32>(el_count) }?; @@ -429,6 +432,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<bf16>(el_count) }?; @@ -446,6 +450,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f16>(el_count) }?; @@ -463,6 +468,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(el_count) }?; @@ -472,6 +478,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f64>(el_count) }?; @@ -485,7 +492,8 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> { + pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> { + let shape = layout.shape(); let src_dims = shape.dims(); let el = shape.elem_count(); let mut dst_el = el; @@ -503,9 +511,10 @@ impl CudaStorage { .collect(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?; + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; let out = dev.alloc_zeros::<u32>(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -514,6 +523,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; let out = dev.alloc_zeros::<bf16>(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -522,6 +532,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; let out = dev.alloc_zeros::<f16>(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -530,6 +541,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; let out = dev.alloc_zeros::<f32>(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -538,6 +550,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; let out = dev.alloc_zeros::<f64>(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -556,21 +569,19 @@ impl CudaStorage { )) } - pub(crate) fn unary_impl<U: crate::op::UnaryOp>( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result<Self> { + pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(_arg) => { todo!("No unary kernels for u32"); } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<bf16>(el_count) }?; @@ -580,6 +591,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f16>(el_count) }?; @@ -589,6 +601,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(el_count) }?; @@ -598,6 +611,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f64>(el_count) }?; @@ -614,17 +628,19 @@ impl CudaStorage { pub(crate) fn binary_impl<B: crate::op::BinaryOp>( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result<Self> { + let shape = lhs_l.shape(); let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<bf16>(elem_count) }?; @@ -634,6 +650,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f16>(elem_count) }?; @@ -643,6 +661,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(elem_count) }?; @@ -652,6 +672,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; let out = unsafe { dev.alloc::<f64>(elem_count) }?; @@ -661,6 +683,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; let out = unsafe { dev.alloc::<u32>(elem_count) }?; @@ -708,28 +732,31 @@ impl CudaStorage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result<Self> { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "where conditions should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; + let ds = + dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?; let slice = match (&t.slice, &f.slice) { (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<bf16>(el) }?; @@ -739,6 +766,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f16>(el) }?; @@ -748,6 +777,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(el) }?; @@ -757,6 +788,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; let out = unsafe { dev.alloc::<f64>(el) }?; @@ -766,6 +799,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; let out = unsafe { dev.alloc::<u32>(el) }?; @@ -775,36 +810,35 @@ impl CudaStorage { CudaStorageSlice::U32(out) } // The dtypes should have been checked at this point so this is an internal error. - _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; let device = dev.clone(); Ok(Self { slice, device }) } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - rhs: &Self, - h_size: usize, // hidden size - v_size: usize, // vocab size - ) -> Result<Self> { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "embedding ids should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let shape = layout.shape(); + let (v_size, h_size) = rhs_l + .shape() + .r2() + .map_err(|e| CudaError::WrappedError(Box::new(e)))?; let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &rhs.slice { // The kernels below assume that rhs is contiguous. CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<u32>(el * h_size) }?; @@ -814,6 +848,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<bf16>(el * h_size) }?; @@ -823,6 +858,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f16>(el * h_size) }?; @@ -832,6 +868,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(el * h_size) }?; @@ -841,6 +878,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f64>(el * h_size) }?; @@ -854,12 +892,12 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result<Self> { let elem_count = b * m * n; let dev = &self.device; @@ -868,7 +906,9 @@ impl CudaStorage { todo!("bf16") } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::<f16>(elem_count) }?; unsafe { self.device @@ -878,7 +918,9 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::<f32>(elem_count) }?; unsafe { self.device @@ -888,7 +930,9 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::<f64>(elem_count) }?; unsafe { self.device @@ -907,13 +951,8 @@ impl CudaStorage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { - if src_shape.rank() != src_stride.len() { - panic!("incoherent shape and strides {src_shape:?} {src_stride:?}") - } let dims = src_shape.dims(); let el_count = src_shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index ef079812..8193b1af 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -100,7 +100,7 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: usize, _: usize) -> Result<Self> { + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 2c9624c7..7acf6dd0 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -167,26 +167,20 @@ impl Storage { (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), - op: "embedding", + op: "where", }), } } - pub(crate) fn embedding( - &self, - layout: &Layout, - rhs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result<Self> { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 93846160..f64bd6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -481,10 +481,10 @@ impl Tensor { } let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; - let (vocab_size, hidden_size) = rhs.shape().r2()?; + let (_, hidden_size) = rhs.shape().r2()?; let storage = ids .storage - .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?; + .embedding(ids.layout(), &rhs.storage, rhs.layout())?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) |