author    Laurent Mazare <laurent.mazare@gmail.com>    2023-06-21 21:37:54 +0100
committer GitHub <noreply@github.com>                  2023-06-21 21:37:54 +0100
commit    db35b310504ab97044b2c3826de72f9bccf86415 (patch)
tree      710596156a4c026d4dd2ba804fab79b6cdafae3b /src
parent    7c317f9611c263f10d661b44151d3655a2fa3b90 (diff)
parent    7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (diff)
Merge pull request #3 from LaurentMazare/cuda
Add Cuda support.
Diffstat (limited to 'src')
-rw-r--r-- | src/cuda_backend.rs       | 134
-rw-r--r-- | src/device.rs             |  55
-rw-r--r-- | src/dummy_cuda_backend.rs |  52
-rw-r--r-- | src/error.rs              |   9
-rw-r--r-- | src/lib.rs                |  11
-rw-r--r-- | src/shape.rs              |  14
-rw-r--r-- | src/storage.rs            |  21
-rw-r--r-- | src/tensor.rs             |  38
8 files changed, 286 insertions, 48 deletions
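
Before the full diff, a minimal usage sketch of the API this merge introduces. This example is editorial and assumption-laden rather than code from the commit: the crate name `candle`, building with the `cuda` cargo feature, and a tensor-level `affine` method (only the storage-level `affine_impl` appears below) are assumptions.

    // Hypothetical usage sketch, not part of the commit.
    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // `Device::new_cuda` is added by this merge; note that tensor
        // constructors now take the device by reference (`&Device`).
        let device = Device::new_cuda(0)?;
        // Host data is uploaded to the GPU through `cuda_from_cpu_storage`.
        let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?;
        // On CUDA this would dispatch to the JIT-compiled `affine_f32`
        // kernel, computing `x * mul + add` element-wise.
        let _y = t.affine(2.0, 1.0)?;
        Ok(())
    }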
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
new file mode 100644
index 00000000..d12db972
--- /dev/null
+++ b/src/cuda_backend.rs
@@ -0,0 +1,134 @@
+use crate::{CpuStorage, DType, Result, Shape};
+use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
+
+pub type CudaError = cudarc::driver::DriverError;
+
+#[derive(Debug, Clone)]
+pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
+
+// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
+const AFFINE_CU: &str = r#"
+extern "C" __global__ void affine_f32(
+    const size_t numel,
+    const float *x,
+    float *y,
+    const float mul,
+    const float add
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+    y[i] = x[i] * mul + add;
+}
+"#;
+
+impl CudaDevice {
+    pub(crate) fn new(ordinal: usize) -> Result<Self> {
+        let device = cudarc::driver::CudaDevice::new(ordinal)?;
+        Ok(Self(device))
+    }
+
+    pub(crate) fn ordinal(&self) -> usize {
+        self.0.ordinal()
+    }
+
+    pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        match dtype {
+            DType::F32 => {
+                let data = self.0.alloc_zeros::<f32>(elem_count)?;
+                Ok(CudaStorage::F32(data))
+            }
+            DType::F64 => {
+                let data = self.0.alloc_zeros::<f64>(elem_count)?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
+        match storage {
+            CpuStorage::F32(storage) => {
+                let data = self.0.htod_sync_copy(storage)?;
+                Ok(CudaStorage::F32(data))
+            }
+            CpuStorage::F64(storage) => {
+                let data = self.0.htod_sync_copy(storage)?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum CudaStorage {
+    F32(CudaSlice<f32>),
+    F64(CudaSlice<f64>),
+}
+
+impl CudaStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
+    pub fn device(&self) -> CudaDevice {
+        match self {
+            Self::F32(slice) => CudaDevice(slice.device()),
+            Self::F64(slice) => CudaDevice(slice.device()),
+        }
+    }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(arg) => {
+                if !shape.is_contiguous(stride) {
+                    todo!("affine is only implemented for the contiguous case")
+                }
+                let dev = arg.device();
+                let module_name = "affine_f32";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let elem_count = shape.elem_count();
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                // SAFETY: if this function returns Ok(..), the kernel has been applied
+                // and has set the initially unset memory.
+                let out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let params = (elem_count, arg, &out, mul as f32, add as f32);
+                // SAFETY: well, well, well...
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            Self::F64(_) => {
+                todo!()
+            }
+        }
+    }
+
+    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        match self {
+            Self::F32(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::F32(cpu_storage))
+            }
+            Self::F64(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::F64(cpu_storage))
+            }
+        }
+    }
+}
diff --git a/src/device.rs b/src/device.rs
index c76cc301..e522cd42 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,11 +1,19 @@
 use crate::{CpuStorage, DType, Result, Shape, Storage};
 
+/// A `DeviceLocation` represents a physical device whereas multiple `Device`
+/// can live on the same location (typically for cuda devices).
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-pub enum Device {
+pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
 }
 
+#[derive(Debug, Clone)]
+pub enum Device {
+    Cpu,
+    Cuda(crate::CudaDevice),
+}
+
 // TODO: Should we back the cpu implementation using the NdArray crate or similar?
 pub trait NdArray {
     fn shape(&self) -> Result<Shape>;
@@ -54,14 +62,31 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }
 
 impl Device {
+    pub fn new_cuda(ordinal: usize) -> Result<Self> {
+        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
+    }
+
+    pub fn location(&self) -> DeviceLocation {
+        match self {
+            Self::Cpu => DeviceLocation::Cpu,
+            Self::Cuda(device) => DeviceLocation::Cuda {
+                gpu_id: device.ordinal(),
+            },
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                // TODO: Instead of allocating memory on the host and transfering it,
+                // allocate some zeros on the device and use a shader to set them to 1.
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
@@ -69,23 +94,23 @@ impl Device {
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::zeros_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
 
     pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
-            Device::Cpu => {
-                let storage = Storage::Cpu(array.to_cpu_storage());
-                Ok(storage)
-            }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
new file mode 100644
index 00000000..f555327f
--- /dev/null
+++ b/src/dummy_cuda_backend.rs
@@ -0,0 +1,52 @@
+#![allow(dead_code)]
+use crate::{CpuStorage, DType, Result, Shape};
+
+pub type CudaError = std::io::Error;
+
+#[derive(Debug, Clone)]
+pub struct CudaDevice;
+
+macro_rules! fail {
+    () => {
+        unimplemented!("cuda support has not been enabled")
+    };
+}
+
+impl CudaDevice {
+    pub(crate) fn new(_: usize) -> Result<Self> {
+        fail!()
+    }
+
+    pub(crate) fn ordinal(&self) -> usize {
+        fail!()
+    }
+
+    pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+        fail!()
+    }
+
+    pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
+        fail!()
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct CudaStorage;
+
+impl CudaStorage {
+    pub fn dtype(&self) -> DType {
+        fail!()
+    }
+
+    pub fn device(&self) -> CudaDevice {
+        fail!()
+    }
+
+    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        fail!()
+    }
+
+    pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
+        fail!()
+    }
+}
diff --git a/src/error.rs b/src/error.rs
index 0114a86c..3f142960 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{DType, Device, Shape};
+use crate::{DType, DeviceLocation, Shape};
 
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
@@ -15,8 +15,8 @@ pub enum Error {
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
-        lhs: Device,
-        rhs: Device,
+        lhs: DeviceLocation,
+        rhs: DeviceLocation,
         op: &'static str,
     },
@@ -33,6 +33,9 @@ pub enum Error {
         got: usize,
         shape: Shape,
     },
+
+    #[error(transparent)]
+    Cuda(#[from] crate::CudaError),
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/src/lib.rs b/src/lib.rs
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,9 @@
 mod cpu_backend;
+#[cfg(feature = "cuda")]
+mod cuda_backend;
 mod device;
 mod dtype;
+mod dummy_cuda_backend;
 mod error;
 mod op;
 mod shape;
@@ -9,10 +12,16 @@
 mod strided_index;
 mod tensor;
 pub use cpu_backend::CpuStorage;
-pub use device::Device;
+pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use shape::Shape;
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
+
+#[cfg(feature = "cuda")]
+pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+
+#[cfg(not(feature = "cuda"))]
+pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
diff --git a/src/shape.rs b/src/shape.rs
index d626aee6..ebc497cf 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -128,6 +128,20 @@ impl Shape {
         stride.reverse();
         stride
     }
+
+    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
+        if self.0.len() != stride.len() {
+            return false;
+        }
+        let mut acc = 1;
+        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
+            if stride != acc {
+                return false;
+            }
+            acc *= dim;
+        }
+        true
+    }
 }
 
 #[cfg(test)]
diff --git a/src/storage.rs b/src/storage.rs
index 7083cc28..573cf945 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,9 +1,9 @@
-use crate::{CpuStorage, DType, Device, Error, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
 
 #[derive(Debug, Clone)]
 pub enum Storage {
     Cpu(CpuStorage),
-    Cuda { gpu_id: usize }, // TODO: Actually add the storage.
+    Cuda(CudaStorage),
 }
 
 pub(crate) trait UnaryOp {
@@ -100,20 +100,20 @@ impl Storage {
     pub fn device(&self) -> Device {
         match self {
             Self::Cpu(_) => Device::Cpu,
-            Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
+            Self::Cuda(storage) => Device::Cuda(storage.device()),
         }
     }
 
     pub fn dtype(&self) -> DType {
         match self {
             Self::Cpu(storage) => storage.dtype(),
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => storage.dtype(),
         }
     }
 
     pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
-        let lhs = self.device();
-        let rhs = rhs.device();
+        let lhs = self.device().location();
+        let rhs = rhs.device().location();
         if lhs != rhs {
             Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
         } else {
@@ -144,7 +144,10 @@ impl Storage {
                 let storage = storage.affine_impl(shape, stride, mul, add)?;
                 Ok(Self::Cpu(storage))
             }
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => {
+                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                Ok(Self::Cuda(storage))
+            }
         }
     }
@@ -179,8 +182,8 @@ impl Storage {
             // Should not happen because of the same device check above but we're defensive
             // anyway.
             Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device(),
-                rhs: rhs.device(),
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
                 op: B::NAME,
             })
         }
diff --git a/src/tensor.rs b/src/tensor.rs
index 9ba412f9..02105573 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -84,7 +84,7 @@ impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -101,22 +101,22 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, true)
     }
 
     pub fn ones_like(&self) -> Result<Self> {
-        Tensor::ones(self.shape(), self.dtype(), self.device())
+        Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -133,21 +133,21 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
     pub fn zeros_like(&self) -> Result<Self> {
-        Tensor::zeros(self.shape(), self.dtype(), self.device())
+        Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
 
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = array.shape()?;
@@ -164,11 +164,11 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, false)
     }
 
-    pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, true)
     }
 
@@ -250,7 +250,12 @@ impl Tensor {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
             }
-            Storage::Cuda { .. } => todo!(),
+            Storage::Cuda(slice) => {
+                // TODO: Would it be possible to only fetch the necessary data?
+                let cpu_storage = slice.to_cpu_storage()?;
+                let data = S::cpu_storage_as_slice(&cpu_storage)?;
+                Ok(self.strided_index().map(|i| data[i]).collect())
+            }
         }
     }
@@ -305,14 +310,7 @@ impl Tensor {
     }
 
     pub fn is_contiguous(&self) -> bool {
-        let mut acc = 1;
-        for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
-            if stride != acc {
-                return false;
-            }
-            acc *= dim;
-        }
-        true
+        self.shape.is_contiguous(&self.stride)
     }
 
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
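
A closing note on the stride logic this merge factors out of `Tensor::is_contiguous` into `Shape::is_contiguous`: a tensor is contiguous exactly when, walking the dimensions from innermost to outermost, each stride equals the product of all inner dimension sizes (so the innermost stride is 1). A standalone restatement of that check; the free function and test values below are illustrative, not from the commit:

    // Row-major contiguity check mirroring the new `Shape::is_contiguous`.
    fn is_contiguous(dims: &[usize], stride: &[usize]) -> bool {
        if dims.len() != stride.len() {
            return false;
        }
        let mut acc = 1;
        // Walk from the innermost dimension outwards, accumulating the
        // expected stride as the product of the inner dimension sizes.
        for (&s, &d) in stride.iter().zip(dims.iter()).rev() {
            if s != acc {
                return false;
            }
            acc *= d;
        }
        true
    }

    fn main() {
        assert!(is_contiguous(&[2, 3], &[3, 1])); // standard row-major layout
        assert!(!is_contiguous(&[2, 3], &[1, 2])); // transposed view, strided
    }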