summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-06-28 15:59:53 +0100
committerGitHub <noreply@github.com>2023-06-28 15:59:53 +0100
commit0cfa21f26a88820fba91bb8ff02cf850eeeb97c3 (patch)
treeefc1279fc9ba273425689e79ac5577801b1bddae /candle-core/src/cuda_backend.rs
parent8b4b2d1830e6fb5aed2c410256bb4e7076e5007d (diff)
parent6c9e6b5a99d4070be5c20d7c383e0ef7e3228260 (diff)
downloadcandle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.tar.gz
candle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.tar.bz2
candle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.zip
Merge pull request #27 from LaurentMazare/layout-refactor
Refactor the stride/shape handling
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs190
1 files changed, 119 insertions, 71 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9cbf82be..9d9a5f99 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, DType, Shape};
+use crate::{CpuStorage, DType, Layout, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
@@ -26,6 +26,9 @@ pub enum CudaError {
#[error("internal error '{0}'")]
InternalError(&'static str),
+ #[error("internal error '{0}'")]
+ WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
+
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
@@ -242,13 +245,14 @@ enum CudaStorageSlice {
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
- src_offset: usize,
+ src_l: &Layout,
dst: &'a mut CudaSlice<T>,
dst_offset: usize,
) -> (
cudarc::driver::CudaView<'a, T>,
cudarc::driver::CudaViewMut<'a, T>,
) {
+ let src_offset = src_l.start_offset();
let to_copy = dst
.len()
.saturating_sub(dst_offset)
@@ -268,12 +272,14 @@ fn gemm_config<T>(
alpha: T,
beta: T,
(b, m, n, k): (usize, usize, usize, usize),
- lhs_stride: &[usize],
- rhs_stride: &[usize],
+ lhs_l: &Layout,
+ rhs_l: &Layout,
) -> Result<StridedBatchedConfig<T>> {
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
use cudarc::cublas::sys::cublasOperation_t;
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
@@ -352,20 +358,27 @@ impl CudaStorage {
&self.device
}
- pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+ pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
use cudarc::driver::DevicePtr;
+ let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
- let ds = dev.htod_copy([dims, stride].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let start_o = layout.start_offset();
+ // This returns an i64 rather than a &i64, this is useful to get around some temporary
+ // lifetime issue and is safe as long as self.slice does not go out of scope before inp
+ // is used.
let inp = match &self.slice {
- CudaStorageSlice::U32(inp) => inp.device_ptr(),
- CudaStorageSlice::BF16(inp) => inp.device_ptr(),
- CudaStorageSlice::F16(inp) => inp.device_ptr(),
- CudaStorageSlice::F32(inp) => inp.device_ptr(),
- CudaStorageSlice::F64(inp) => inp.device_ptr(),
+ CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
};
+ let inp = &inp;
+
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
let slice = match dtype {
@@ -406,20 +419,16 @@ impl CudaStorage {
})
}
- pub(crate) fn affine_impl(
- &self,
- shape: &Shape,
- stride: &[usize],
- mul: f64,
- add: f64,
- ) -> Result<Self> {
+ pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+ let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = self.device();
- let ds = dev.htod_copy([dims, stride].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el_count) }?;
@@ -429,6 +438,7 @@ impl CudaStorage {
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
@@ -446,6 +456,7 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
@@ -463,6 +474,7 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
@@ -472,6 +484,7 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
@@ -485,7 +498,8 @@ impl CudaStorage {
Ok(Self { slice, device })
}
- pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
+ pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+ let shape = layout.shape();
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
@@ -503,9 +517,10 @@ impl CudaStorage {
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
- let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
+ let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<u32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -514,6 +529,7 @@ impl CudaStorage {
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<bf16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -522,6 +538,7 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -530,6 +547,7 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -538,6 +556,7 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f64>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -556,21 +575,19 @@ impl CudaStorage {
))
}
- pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
- &self,
- shape: &Shape,
- stride: &[usize],
- ) -> Result<Self> {
+ pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+ let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
- let ds = dev.htod_copy([dims, stride].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::BF16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
@@ -580,6 +597,7 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
@@ -589,6 +607,7 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
@@ -598,6 +617,7 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
+ let arg = &arg.slice(layout.start_offset()..);
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
@@ -614,17 +634,19 @@ impl CudaStorage {
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
&self,
rhs: &Self,
- shape: &Shape,
- lhs_stride: &[usize],
- rhs_stride: &[usize],
+ lhs_l: &Layout,
+ rhs_l: &Layout,
) -> Result<Self> {
+ let shape = lhs_l.shape();
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = self.device();
- let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
+ let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
@@ -634,6 +656,8 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
@@ -643,6 +667,8 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
@@ -652,6 +678,8 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
@@ -661,6 +689,8 @@ impl CudaStorage {
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
@@ -708,28 +738,32 @@ impl CudaStorage {
pub(crate) fn where_cond(
&self,
- shape: &Shape,
- stride: &[usize],
+ layout: &Layout,
t: &Self,
- stride_t: &[usize],
+ layout_t: &Layout,
f: &Self,
- stride_f: &[usize],
+ layout_f: &Layout,
) -> Result<Self> {
let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice,
+ CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
+ let ids = &ids;
+ let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
- let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
+ let ds =
+ dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
@@ -739,6 +773,8 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
@@ -748,6 +784,8 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
@@ -757,6 +795,8 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
@@ -766,6 +806,8 @@ impl CudaStorage {
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
@@ -775,36 +817,36 @@ impl CudaStorage {
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
- _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+ _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
let device = dev.clone();
Ok(Self { slice, device })
}
- pub(crate) fn embedding_impl(
- &self,
- shape: &Shape,
- stride: &[usize],
- rhs: &Self,
- h_size: usize, // hidden size
- v_size: usize, // vocab size
- ) -> Result<Self> {
+ pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice,
+ CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
+ let ids = &ids;
+ let shape = layout.shape();
+ let (v_size, h_size) = rhs_l
+ .shape()
+ .r2()
+ .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
- let ds = dev.htod_copy([dims, stride].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
+ let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
@@ -814,6 +856,7 @@ impl CudaStorage {
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
+ let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
@@ -823,6 +866,7 @@ impl CudaStorage {
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
+ let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
@@ -832,6 +876,7 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
+ let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
@@ -841,6 +886,7 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
+ let arg = &arg.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
@@ -854,12 +900,12 @@ impl CudaStorage {
Ok(Self { slice, device })
}
- pub(crate) fn matmul_impl(
+ pub(crate) fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
- lhs_stride: &[usize],
- rhs_stride: &[usize],
+ lhs_l: &Layout,
+ rhs_l: &Layout,
) -> Result<Self> {
let elem_count = b * m * n;
let dev = &self.device;
@@ -868,7 +914,9 @@ impl CudaStorage {
todo!("bf16")
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
- let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
unsafe {
self.device
@@ -878,7 +926,9 @@ impl CudaStorage {
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
- let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
unsafe {
self.device
@@ -888,7 +938,9 @@ impl CudaStorage {
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
- let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
unsafe {
self.device
@@ -907,22 +959,18 @@ impl CudaStorage {
&self,
dst: &mut Self,
dst_offset: usize,
- src_shape: &Shape,
- src_stride: &[usize],
- src_offset: usize,
+ src_l: &Layout,
) -> Result<()> {
- if src_shape.rank() != src_stride.len() {
- panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
- }
+ let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
- let ds = dev.htod_copy([dims, src_stride].concat())?;
+ let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
@@ -933,8 +981,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
@@ -945,8 +993,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
@@ -957,8 +1005,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
@@ -969,8 +1017,8 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
- let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
- if src_shape.is_contiguous(src_stride) {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;