summary | refs | log | tree | commit | diff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
author laurent <laurent.mazare@gmail.com> 2023-06-28 15:53:23 +0100
committer laurent <laurent.mazare@gmail.com> 2023-06-28 15:53:23 +0100
commit 6c9e6b5a99d4070be5c20d7c383e0ef7e3228260 (patch)
tree e49da5272c2c6f8d42a5bb5ad2d0f54a39e62979 /candle-core/src/cuda_backend.rs
parent 3f0d9fbb257baf94acde184de76eb9667e0fa025 (diff)
download candle-6c9e6b5a99d4070be5c20d7c383e0ef7e3228260.tar.gz
candle-6c9e6b5a99d4070be5c20d7c383e0ef7e3228260.tar.bz2
candle-6c9e6b5a99d4070be5c20d7c383e0ef7e3228260.zip
Get the cuda tests to pass.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- candle-core/src/cuda_backend.rs 49
1 files changed, 29 insertions, 20 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f50d7cbb..9d9a5f99 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -27,7 +27,7 @@ pub enum CudaError {
InternalError(&'static str),
#[error("internal error '{0}'")]
- WrappedError(Box<dyn std::error::Error>),
+ WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
@@ -245,13 +245,14 @@ enum CudaStorageSlice {
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
- src_offset: usize,
+ src_l: &Layout,
dst: &'a mut CudaSlice<T>,
dst_offset: usize,
) -> (
cudarc::driver::CudaView<'a, T>,
cudarc::driver::CudaViewMut<'a, T>,
) {
+ let src_offset = src_l.start_offset();
let to_copy = dst
.len()
.saturating_sub(dst_offset)
@@ -366,13 +367,18 @@ impl CudaStorage {
let dev = self.device();
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let start_o = layout.start_offset();
+ // This returns an i64 rather than a &i64; doing so works around a temporary
+ // lifetime issue and is safe as long as self.slice does not go out of scope before inp
+ // is used.
let inp = match &self.slice {
- CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(),
- CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(),
- CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(),
- CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(),
- CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
};
+ let inp = &inp;
+
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
let slice = match dtype {
@@ -739,13 +745,14 @@ impl CudaStorage {
layout_f: &Layout,
) -> Result<Self> {
let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
+ CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
+ let ids = &ids;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
@@ -818,13 +825,14 @@ impl CudaStorage {
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
+ CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
+ let ids = &ids;
let shape = layout.shape();
let (v_size, h_size) = rhs_l
.shape()
@@ -953,15 +961,16 @@ impl CudaStorage {
dst_offset: usize,
src_l: &Layout,
) -> Result<()> {
+ let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
- let ds = dev.htod_copy([dims, src_stride].concat())?;
+ let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
@@ -972,8 +981,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
@@ -984,8 +993,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
@@ -996,8 +1005,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
@@ -1008,8 +1017,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;