-rw-r--r--  candle-core/src/backend.rs              71
-rw-r--r--  candle-core/src/conv.rs                  2
-rw-r--r--  candle-core/src/cuda_backend.rs        515
-rw-r--r--  candle-core/src/device.rs               13
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  118
-rw-r--r--  candle-core/src/error.rs                 2
-rw-r--r--  candle-core/src/lib.rs                   5
-rw-r--r--  candle-core/src/storage.rs               1
-rw-r--r--  candle-core/src/tensor.rs                7
9 files changed, 409 insertions, 325 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
new file mode 100644
index 00000000..aa35703d
--- /dev/null
+++ b/candle-core/src/backend.rs
@@ -0,0 +1,71 @@
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+
+pub(crate) trait BackendStorage: Sized {
+ type Device: BackendDevice;
+
+ fn try_clone(&self, _: &Layout) -> Result<Self>;
+
+ fn dtype(&self) -> DType;
+
+ fn device(&self) -> &Self::Device;
+
+ fn to_cpu_storage(&self) -> Result<CpuStorage>;
+
+ fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
+
+ fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
+
+ fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
+
+ fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
+
+ fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
+
+ fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
+
+ fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
+ -> Result<Self>;
+
+ fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+ fn conv1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConv1D,
+ ) -> Result<Self>;
+
+ fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+
+ fn matmul(
+ &self,
+ _: &Self,
+ _: (usize, usize, usize, usize),
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<Self>;
+
+ fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
+}
+
+pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
+ type Storage: BackendStorage;
+
+ // TODO: Make the usize generic and part of a generic DeviceLocation.
+ fn new(_: usize) -> Result<Self>;
+
+ fn location(&self) -> crate::DeviceLocation;
+
+ fn same_device(&self, _: &Self) -> bool;
+
+ fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+ fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
+ fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
+
+ fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+ fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+}
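
For illustration, a minimal sketch (not part of this commit; `zeros_for` and `to_host` are hypothetical helper names) of how crate-internal code can stay generic over the two new traits:

    use crate::backend::{BackendDevice, BackendStorage};
    use crate::{CpuStorage, DType, Result, Shape};

    // Generic over any backend device: allocate zero-filled storage for a shape/dtype.
    fn zeros_for<D: BackendDevice>(dev: &D, shape: &Shape, dtype: DType) -> Result<D::Storage> {
        dev.zeros_impl(shape, dtype)
    }

    // Generic over any backend storage: copy the data back into host memory.
    fn to_host<S: BackendStorage>(storage: &S) -> Result<CpuStorage> {
        storage.to_cpu_storage()
    }
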
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 041bb6fb..4cf9d0ad 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,5 +1,5 @@
#[derive(Debug, Clone, PartialEq, Eq)]
-pub(crate) struct ParamsConv1D {
+pub struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index c2d01a07..73bc9e34 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,5 @@
-use crate::{CpuStorage, DType, Layout, Shape, WithDType};
+use crate::backend::{BackendDevice, BackendStorage};
+use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
@@ -22,9 +23,6 @@ pub enum CudaError {
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
- #[error("{op} only supports contiguous tensors")]
- RequiresContiguous { op: &'static str },
-
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
@@ -58,7 +56,11 @@ pub enum CudaError {
},
}
-type Result<T> = std::result::Result<T, CudaError>;
+impl From<CudaError> for crate::Error {
+ fn from(val: CudaError) -> Self {
+ crate::Error::Cuda(Box::new(val))
+ }
+}
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -98,11 +100,105 @@ impl std::ops::Deref for CudaDevice {
}
}
+trait WrapErr<O> {
+ fn w(self) -> std::result::Result<O, crate::Error>;
+}
+
+impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
+ fn w(self) -> std::result::Result<O, crate::Error> {
+ self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
+ }
+}
+
impl CudaDevice {
- pub(crate) fn new(ordinal: usize) -> Result<Self> {
- let device = cudarc::driver::CudaDevice::new(ordinal)?;
- let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
- let curand = cudarc::curand::CudaRng::new(299792458, device.clone())?;
+ fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ let elem_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let slice = match dtype {
+ DType::U8 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
+ let params = (&data, v as u8, elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::U8(data)
+ }
+ DType::U32 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
+ let params = (&data, v as u32, elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::U32(data)
+ }
+ DType::BF16 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
+ let params = (&data, bf16::from_f64(v), elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::BF16(data)
+ }
+ DType::F16 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
+ let params = (&data, f16::from_f64(v), elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::F16(data)
+ }
+ DType::F32 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
+ let params = (&data, v as f32, elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::F32(data)
+ }
+ DType::F64 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
+ let params = (&data, v, elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::F64(data)
+ }
+ };
+ Ok(CudaStorage {
+ slice,
+ device: self.clone(),
+ })
+ }
+
+ fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
+ if !self.has_func(module_name, module_name) {
+ // Leaking the string here is a bit sad but we need a &'static str and this is only
+ // done once per kernel name.
+ let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
+ self.load_ptx(ptx.into(), module_name, &[static_module_name])
+ .map_err(|cuda| CudaError::Load {
+ cuda,
+ module_name: module_name.to_string(),
+ })
+ .w()?;
+ }
+ self.get_func(module_name, module_name)
+ // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
+ // able to only build the error value if needed.
+ .ok_or(CudaError::MissingKernel {
+ module_name: module_name.to_string(),
+ })
+ .w()
+ }
+}
+
+impl BackendDevice for CudaDevice {
+ type Storage = CudaStorage;
+
+ fn new(ordinal: usize) -> Result<Self> {
+ let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
+ let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
+ let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
@@ -111,39 +207,41 @@ impl CudaDevice {
})
}
- pub(crate) fn same_id(&self, rhs: &Self) -> bool {
- self.id == rhs.id
+ fn location(&self) -> crate::DeviceLocation {
+ crate::DeviceLocation::Cuda {
+ gpu_id: self.device.ordinal(),
+ }
}
- pub(crate) fn ordinal(&self) -> usize {
- self.device.ordinal()
+ fn same_device(&self, rhs: &Self) -> bool {
+ self.id == rhs.id
}
- pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
- let data = self.alloc_zeros::<u8>(elem_count)?;
+ let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
- let data = self.alloc_zeros::<u32>(elem_count)?;
+ let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
- let data = self.alloc_zeros::<bf16>(elem_count)?;
+ let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
- let data = self.alloc_zeros::<f16>(elem_count)?;
+ let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
- let data = self.alloc_zeros::<f32>(elem_count)?;
+ let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let data = self.alloc_zeros::<f64>(elem_count)?;
+ let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
@@ -153,30 +251,23 @@ impl CudaDevice {
})
}
- pub(crate) fn rand_uniform(
- &self,
- shape: &Shape,
- dtype: DType,
- lo: f64,
- up: f64,
- ) -> Result<CudaStorage> {
+ fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
- DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
- Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_uniform",
- })?
- }
+ DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_uniform",
+ })
+ .w()?,
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
- curand.0.fill_with_uniform(&mut data)?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+ curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
- curand.0.fill_with_uniform(&mut data)?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
@@ -190,91 +281,26 @@ impl CudaDevice {
})
}
- pub(crate) fn rand_normal(
- &self,
- shape: &Shape,
- dtype: DType,
- mean: f64,
- std: f64,
- ) -> Result<CudaStorage> {
+ fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
- DType::U8 | DType::U32 | DType::F16 | DType::BF16 => {
- Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_normal",
- })?
- }
+ DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_normal",
+ })
+ .w()?,
DType::F32 => {
- let mut data = unsafe { self.alloc::<f32>(elem_count) }?;
+ let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand
.0
- .fill_with_normal(&mut data, mean as f32, std as f32)?;
- CudaStorageSlice::F32(data)
- }
- DType::F64 => {
- let mut data = unsafe { self.alloc::<f64>(elem_count) }?;
- curand.0.fill_with_normal(&mut data, mean, std)?;
- CudaStorageSlice::F64(data)
- }
- };
- Ok(CudaStorage {
- slice,
- device: self.clone(),
- })
- }
-
- pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
- let elem_count = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
- let slice = match dtype {
- DType::U8 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<u8>(elem_count) }?;
- let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
- let params = (&data, v as u8, elem_count);
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U8(data)
- }
- DType::U32 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<u32>(elem_count) }?;
- let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
- let params = (&data, v as u32, elem_count);
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(data)
- }
- DType::BF16 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<bf16>(elem_count) }?;
- let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
- let params = (&data, bf16::from_f64(v), elem_count);
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(data)
- }
- DType::F16 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<f16>(elem_count) }?;
- let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
- let params = (&data, f16::from_f64(v), elem_count);
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(data)
- }
- DType::F32 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<f32>(elem_count) }?;
- let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
- let params = (&data, v as f32, elem_count);
- unsafe { func.launch(cfg, params) }?;
+ .fill_with_normal(&mut data, mean as f32, std as f32)
+ .w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
- // SAFETY: Set later by running the fill kernel.
- let data = unsafe { self.alloc::<f64>(elem_count) }?;
- let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
- let params = (&data, v, elem_count);
- unsafe { func.launch(cfg, params) }?;
+ let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+ curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
@@ -284,34 +310,34 @@ impl CudaDevice {
})
}
- pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
- pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
+ fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::BF16(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
- let data = self.htod_sync_copy(storage)?;
+ let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@@ -320,25 +346,6 @@ impl CudaDevice {
device: self.clone(),
})
}
-
- fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
- if !self.has_func(module_name, module_name) {
- // Leaking the string here is a bit sad but we need a &'static str and this is only
- // done once per kernel name.
- let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
- self.load_ptx(ptx.into(), module_name, &[static_module_name])
- .map_err(|cuda| CudaError::Load {
- cuda,
- module_name: module_name.to_string(),
- })?;
- }
- self.get_func(module_name, module_name)
- // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
- // able to only build the error value if needed.
- .ok_or(CudaError::MissingKernel {
- module_name: module_name.to_string(),
- })
- }
}
#[derive(Debug)]
@@ -391,7 +398,7 @@ trait Map2 {
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
- _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+ _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
@@ -405,7 +412,7 @@ impl Map1 for Clone {
_: &CudaDevice,
_: &Layout,
) -> Result<CudaSlice<T>> {
- Ok(s.try_clone()?)
+ s.try_clone().w()
}
}
@@ -426,11 +433,11 @@ impl Map1 for Affine {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el) }?;
+ let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
el,
dims.len(),
@@ -441,7 +448,7 @@ impl Map1 for Affine {
T::from_f64(self.1),
);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -458,14 +465,14 @@ impl Map1 for Elu {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el) }?;
+ let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -495,13 +502,15 @@ impl<'a> Map1 for Sum<'a> {
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
+ let ds = dev
+ .htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
+ .w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
- let out = dev.alloc_zeros::<T>(dst_el)?;
+ let out = dev.alloc_zeros::<T>(dst_el).w()?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -544,13 +553,15 @@ impl<'a> Map1 for FastSum<'a> {
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
- let ds = dev.htod_copy([dims.as_slice(), stride.as_slice()].concat())?;
+ let ds = dev
+ .htod_copy([dims.as_slice(), stride.as_slice()].concat())
+ .w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
- let out = dev.alloc_zeros::<T>(dst_el)?;
+ let out = dev.alloc_zeros::<T>(dst_el).w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -566,14 +577,14 @@ impl<U: crate::op::UnaryOp> Map1 for U {
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el_count) }?;
+ let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -593,25 +604,27 @@ impl<'a> Map1 for Embedding<'a> {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.0.dtype(),
- })?,
+ })
+ .w()?,
};
let ids = &ids;
let shape = ids_l.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
- .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
+ .map_err(|e| CudaError::WrappedError(Box::new(e)))
+ .w()?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
+ let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el * h_size) }?;
+ let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -640,7 +653,7 @@ impl<'a> Map2 for Conv1D<'a> {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(dst_el) }?;
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else if dims.len() == 2 {
@@ -648,10 +661,10 @@ impl<'a> Map2 for Conv1D<'a> {
} else {
panic!("unexpected input shape for conv1d {dims:?}")
};
- let ds = dev.htod_copy(ds)?;
+ let ds = dev.htod_copy(ds).w()?;
let params = (el, l_out, p.stride, &ds, inp, k, &out);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -673,23 +686,25 @@ impl<'a> Map2 for WhereCond<'a> {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.0.dtype(),
- })?,
+ })
+ .w()?,
};
let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
- let ds =
- dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
+ let ds = dev
+ .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
+ .w()?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(el) }?;
+ let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -707,15 +722,17 @@ impl<U: crate::op::BinaryOp> Map2 for U {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
- let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
+ let dims_and_strides = dev
+ .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+ .w()?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<T>(elem_count) }?;
+ let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@@ -771,7 +788,8 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
- })?
+ })
+ .w()?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -783,7 +801,8 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
- })?
+ })
+ .w()?
};
// The setup below was copied from:
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@@ -808,7 +827,8 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
- })?,
+ })
+ .w()?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@@ -818,7 +838,8 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
- })?,
+ })
+ .w()?,
};
Ok(StridedBatchedConfig {
@@ -830,14 +851,16 @@ fn gemm_config<T>(
})
}
-impl CudaStorage {
- pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
+impl BackendStorage for CudaStorage {
+ type Device = CudaDevice;
+
+ fn try_clone(&self, layout: &Layout) -> Result<Self> {
let slice = Clone.map(&self.slice, self.device(), layout)?;
let device = self.device.clone();
Ok(Self { slice, device })
}
- pub fn dtype(&self) -> DType {
+ fn dtype(&self) -> DType {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
@@ -848,18 +871,18 @@ impl CudaStorage {
}
}
- pub fn device(&self) -> &CudaDevice {
+ fn device(&self) -> &CudaDevice {
&self.device
}
- pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+ fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
use cudarc::driver::DevicePtr;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let start_o = layout.start_offset();
// This returns an i64 rather than a &i64, this is useful to get around some temporary
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -878,39 +901,39 @@ impl CudaStorage {
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
let slice = match dtype {
DType::U8 => {
- let out = unsafe { dev.alloc::<u8>(el) }?;
+ let out = unsafe { dev.alloc::<u8>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(out)
}
DType::U32 => {
- let out = unsafe { dev.alloc::<u32>(el) }?;
+ let out = unsafe { dev.alloc::<u32>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
DType::BF16 => {
- let out = unsafe { dev.alloc::<bf16>(el) }?;
+ let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(out)
}
DType::F16 => {
- let out = unsafe { dev.alloc::<f16>(el) }?;
+ let out = unsafe { dev.alloc::<f16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(out)
}
DType::F32 => {
- let out = unsafe { dev.alloc::<f32>(el) }?;
+ let out = unsafe { dev.alloc::<f32>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(out)
}
DType::F64 => {
- let out = unsafe { dev.alloc::<f64>(el) }?;
+ let out = unsafe { dev.alloc::<f64>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(out)
}
};
@@ -920,37 +943,35 @@ impl CudaStorage {
})
}
- pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+ fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
- pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+ fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
- pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+ fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device().clone();
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
- pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
- Err(CudaError::InternalError(
- "TODO: implement divide_by_sum_over_dim",
- ))
+ fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+ Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
- pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+ fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
- pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
+ fn binary_impl<B: crate::op::BinaryOp>(
&self,
rhs: &Self,
lhs_l: &Layout,
@@ -961,42 +982,42 @@ impl CudaStorage {
Ok(Self { slice, device })
}
- pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ fn to_cpu_storage(&self) -> Result<CpuStorage> {
match &self.slice {
CudaStorageSlice::U8(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U8(cpu_storage))
}
CudaStorageSlice::U32(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::BF16(cpu_storage))
}
CudaStorageSlice::F16(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::F16(cpu_storage))
}
CudaStorageSlice::F32(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::F32(cpu_storage))
}
CudaStorageSlice::F64(slice) => {
let dev = slice.device();
- let cpu_storage = dev.dtoh_sync_copy(slice)?;
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::F64(cpu_storage))
}
}
}
- pub(crate) fn where_cond(
+ fn where_cond(
&self,
layout: &Layout,
t: &Self,
@@ -1009,7 +1030,7 @@ impl CudaStorage {
Ok(Self { slice, device })
}
- pub(crate) fn conv1d(
+ fn conv1d(
&self,
l: &Layout,
kernel: &Self,
@@ -1021,13 +1042,13 @@ impl CudaStorage {
Ok(Self { slice, device })
}
- pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+ fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
Ok(Self { slice, device })
}
- pub(crate) fn matmul(
+ fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
@@ -1041,146 +1062,144 @@ impl CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
- let mut out = unsafe { dev.alloc::<bf16>(elem_count) }?;
+ let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }?;
+ }
+ .w()?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
- let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
+ let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }?;
+ }
+ .w()?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
- let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
+ let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }?;
+ }
+ .w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
- let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
+ let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }?;
+ }
+ .w()?;
CudaStorageSlice::F64(out)
}
- _ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
+ _ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
};
let device = dev.clone();
Ok(Self { slice, device })
}
- pub(crate) fn copy_strided_src(
- &self,
- dst: &mut Self,
- dst_offset: usize,
- src_l: &Layout,
- ) -> Result<()> {
+ fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
- let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
+ let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?
+ unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?
+ unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?
+ unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?
+ unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?
+ unsafe { func.launch(cfg, params) }.w()?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
- dev.dtod_copy(&src, &mut dst)?
+ dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
+ unsafe { func.launch(cfg, params) }.w()?;
}
}
- _ => {
- return Err(CudaError::InternalError(
- "dtype mismatch in copy_strided op",
- ))
- }
+ _ => Err(CudaError::InternalError(
+ "dtype mismatch in copy_strided op",
+ ))
+ .w()?,
}
Ok(())
}
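
Most of the churn above replaces the file-local `type Result<T> = std::result::Result<T, CudaError>` with `crate::Result` plus the small `.w()` adapter. A standalone sketch of that adapter pattern, with simplified placeholder types (the real code converts through `CudaError` and boxes into `crate::Error::Cuda`):

    #[derive(Debug)]
    enum Error {
        // Backend-specific errors are type-erased behind a Box, mirroring the
        // new crate::Error::Cuda variant.
        Cuda(Box<dyn std::error::Error + Send + Sync>),
    }

    trait WrapErr<O> {
        fn w(self) -> Result<O, Error>;
    }

    impl<O, E: std::error::Error + Send + Sync + 'static> WrapErr<O> for Result<O, E> {
        fn w(self) -> Result<O, Error> {
            // Wrap any library error into the crate-level variant so `?` keeps
            // working against a single Result type.
            self.map_err(|e| Error::Cuda(Box::new(e)))
        }
    }

    fn parse_scale(s: &str) -> Result<f32, Error> {
        // `.w()?` here plays the same role as the `.w()?` calls on cudarc
        // results in the diff above.
        let v: f32 = s.parse().w()?;
        Ok(v)
    }
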
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 1380cbc9..b428922b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
@@ -85,10 +86,10 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
- pub fn same_id(&self, rhs: &Self) -> bool {
+ pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
- (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
+ (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@@ -96,9 +97,7 @@ impl Device {
pub fn location(&self) -> DeviceLocation {
match self {
Self::Cpu => DeviceLocation::Cpu,
- Self::Cuda(device) => DeviceLocation::Cuda {
- gpu_id: device.ordinal(),
- },
+ Self::Cuda(device) => device.location(),
}
}
@@ -178,7 +177,7 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
- let storage = device.cuda_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
@@ -189,7 +188,7 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
- let storage = device.cuda_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index f5c80fcf..a81dda57 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,142 +1,134 @@
#![allow(dead_code)]
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
-#[derive(thiserror::Error, Debug)]
-pub enum DummyError {}
-pub type CudaError = DummyError;
-
#[derive(Debug, Clone)]
pub struct CudaDevice;
+#[derive(Debug)]
+pub struct CudaStorage;
+
macro_rules! fail {
() => {
unimplemented!("cuda support has not been enabled")
};
}
-impl CudaDevice {
- pub(crate) fn new(_: usize) -> Result<Self> {
+impl crate::backend::BackendStorage for CudaStorage {
+ type Device = CudaDevice;
+
+ fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn same_id(&self, _: &Self) -> bool {
- true
+ fn dtype(&self) -> DType {
+ fail!()
}
- pub(crate) fn ordinal(&self) -> usize {
+ fn device(&self) -> &Self::Device {
fail!()
}
- pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+ fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+ fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
+ fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
+ fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<CudaStorage> {
+ fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
-}
-#[derive(Debug)]
-pub struct CudaStorage;
-
-impl CudaStorage {
- pub fn try_clone(&self, _: &Layout) -> Result<Self> {
+ fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub fn dtype(&self) -> DType {
- fail!()
+ fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
}
- pub fn device(&self) -> &CudaDevice {
- fail!()
+ fn binary_impl<B: crate::op::BinaryOp>(
+ &self,
+ _: &Self,
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
+ fn conv1d(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+ fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
+ fn matmul(
+ &self,
+ _: &Self,
+ _: (usize, usize, usize, usize),
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
+ fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
+}
- pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
+impl crate::backend::BackendDevice for CudaDevice {
+ type Storage = CudaStorage;
+ fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
- Err(Error::NotCompiledWithCudaSupport)
+ fn location(&self) -> crate::DeviceLocation {
+ fail!()
}
- pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
- &self,
- _: &Self,
- _: &Layout,
- _: &Layout,
- ) -> Result<Self> {
- Err(Error::NotCompiledWithCudaSupport)
+ fn same_device(&self, _: &Self) -> bool {
+ fail!()
}
- pub(crate) fn where_cond(
- &self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- ) -> Result<Self> {
+ fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn conv1d(
- &self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &crate::conv::ParamsConv1D,
- ) -> Result<Self> {
+ fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+ fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn matmul(
- &self,
- _: &Self,
- _: (usize, usize, usize, usize),
- _: &Layout,
- _: &Layout,
- ) -> Result<Self> {
+ fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
- pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+ fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index d8f3b4b4..caad3e1f 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -100,7 +100,7 @@ pub enum Error {
},
#[error(transparent)]
- Cuda(#[from] crate::CudaError),
+ Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index d36f90af..06fc87d1 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -33,6 +33,7 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
+mod backend;
mod backprop;
mod conv;
mod cpu_backend;
@@ -68,10 +69,10 @@ use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
#[cfg(feature = "cuda")]
-pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
-pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
+pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index ee12eeb8..5f92172d 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,3 +1,4 @@
+use crate::backend::BackendStorage;
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ecc018f9..5d4e106f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::backend::{BackendDevice, BackendStorage};
use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::Arc;
@@ -963,19 +964,19 @@ impl Tensor {
/// If the target device is the same as the tensor device, only a shallow copy is performed.
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
- if self.device().same_id(device) {
+ if self.device().same_device(device) {
Ok(self.clone())
} else {
let storage = match (self.storage.as_ref(), device) {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
- Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
+ Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
let cpu_storage = storage.to_cpu_storage()?;
- Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
+ Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
};
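
A hedged usage sketch of the renamed device APIs (assumes a CUDA-enabled build and the existing `Tensor::zeros` constructor; not part of this commit):

    use crate::{DType, Device, Result, Tensor};

    fn move_to_gpu() -> Result<()> {
        let cpu = Device::Cpu;
        let gpu = Device::new_cuda(0)?;
        // `same_device` replaces the old `same_id` helper.
        assert!(!cpu.same_device(&gpu));
        let t = Tensor::zeros((2, 3), DType::F32, &cpu)?;
        // CPU -> CUDA moves go through `storage_from_cpu_storage`; when source and
        // target devices already match, `to_device` is only a shallow copy.
        let t_gpu = t.to_device(&gpu)?;
        assert_eq!(t_gpu.shape(), t.shape());
        Ok(())
    }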