author    Laurent Mazare <laurent.mazare@gmail.com>  2023-06-21 21:37:54 +0100
committer GitHub <noreply@github.com>  2023-06-21 21:37:54 +0100
commit    db35b310504ab97044b2c3826de72f9bccf86415 (patch)
tree      710596156a4c026d4dd2ba804fab79b6cdafae3b /src
parent    7c317f9611c263f10d661b44151d3655a2fa3b90 (diff)
parent    7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (diff)
Merge pull request #3 from LaurentMazare/cuda
Add Cuda support.
Diffstat (limited to 'src')
 src/cuda_backend.rs       | 134
 src/device.rs             |  55
 src/dummy_cuda_backend.rs |  52
 src/error.rs              |   9
 src/lib.rs                |  11
 src/shape.rs              |  14
 src/storage.rs            |  21
 src/tensor.rs             |  38
 8 files changed, 286 insertions(+), 48 deletions(-)
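
This merge threads a first CUDA op, `affine`, through the whole stack: host-to-device copy, runtime kernel compilation, launch, and device-to-host readback. A minimal usage sketch of the post-merge API (a hypothetical snippet, assuming the crate is used as `candle` with the `cuda` feature enabled and that `Tensor::affine` and `to_vec1` are exposed at this point):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // Without the `cuda` feature this hits the dummy backend, which panics.
    let device = Device::new_cuda(0)?;
    // Host data reaches the GPU through `cuda_from_cpu_storage`.
    let x = Tensor::new(&[1f32, 2., 3.], &device)?;
    // Runs the `affine_f32` kernel below: y[i] = x[i] * 2.0 + 1.0.
    let y = x.affine(2.0, 1.0)?;
    // Reading values back goes through `to_cpu_storage`.
    println!("{:?}", y.to_vec1::<f32>()?);
    Ok(())
}
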
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
new file mode 100644
index 00000000..d12db972
--- /dev/null
+++ b/src/cuda_backend.rs
@@ -0,0 +1,134 @@
+use crate::{CpuStorage, DType, Result, Shape};
+use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
+
+pub type CudaError = cudarc::driver::DriverError;
+
+#[derive(Debug, Clone)]
+pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
+
+// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
+const AFFINE_CU: &str = r#"
+extern "C" __global__ void affine_f32(
+    const size_t numel,
+    const float *x,
+    float *y,
+    const float mul,
+    const float add
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+    y[i] = x[i] * mul + add;
+}
+"#;
+
+impl CudaDevice {
+    pub(crate) fn new(ordinal: usize) -> Result<Self> {
+        let device = cudarc::driver::CudaDevice::new(ordinal)?;
+        Ok(Self(device))
+    }
+
+    pub(crate) fn ordinal(&self) -> usize {
+        self.0.ordinal()
+    }
+
+    pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        match dtype {
+            DType::F32 => {
+                let data = self.0.alloc_zeros::<f32>(elem_count)?;
+                Ok(CudaStorage::F32(data))
+            }
+            DType::F64 => {
+                let data = self.0.alloc_zeros::<f64>(elem_count)?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
+        match storage {
+            CpuStorage::F32(storage) => {
+                let data = self.0.htod_sync_copy(storage)?;
+                Ok(CudaStorage::F32(data))
+            }
+            CpuStorage::F64(storage) => {
+                let data = self.0.htod_sync_copy(storage)?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum CudaStorage {
+    F32(CudaSlice<f32>),
+    F64(CudaSlice<f64>),
+}
+
+impl CudaStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
+    pub fn device(&self) -> CudaDevice {
+        match self {
+            Self::F32(slice) => CudaDevice(slice.device()),
+            Self::F64(slice) => CudaDevice(slice.device()),
+        }
+    }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(arg) => {
+                if !shape.is_contiguous(stride) {
+                    todo!("affine is only implemented for the contiguous case")
+                }
+                let dev = arg.device();
+                let module_name = "affine_f32";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let elem_count = shape.elem_count();
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                // SAFETY: if this function returns Ok(..), the kernel has been applied
+                // and has set the initially unset memory.
+                let out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let params = (elem_count, arg, &out, mul as f32, add as f32);
+                // SAFETY: the parameter tuple matches the kernel signature, and the
+                // launch config covers exactly `elem_count` threads.
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            Self::F64(_) => {
+                todo!()
+            }
+        }
+    }
+
+    pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        match self {
+            Self::F32(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::F32(cpu_storage))
+            }
+            Self::F64(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice)?;
+                Ok(CpuStorage::F64(cpu_storage))
+            }
+        }
+    }
+}
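
The backend compiles its kernel at first use with NVRTC, caches it on the device (`has_func` / `load_ptx`), and launches one thread per element. A self-contained sketch of that cudarc flow, using a hypothetical `scale_f32` kernel rather than the `affine_f32` above:

use cudarc::driver::{CudaDevice, LaunchAsync, LaunchConfig};

const SCALE_CU: &str = r#"
extern "C" __global__ void scale_f32(const size_t numel, const float *x, float *y, const float m) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < numel) {
        y[i] = x[i] * m;
    }
}
"#;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dev = CudaDevice::new(0)?;
    // Compile CUDA C to PTX at runtime, then register it under a module name.
    let ptx = cudarc::nvrtc::compile_ptx(SCALE_CU)?;
    dev.load_ptx(ptx, "scale_f32", &["scale_f32"])?;
    let f = dev.get_func("scale_f32", "scale_f32").unwrap();

    let x = dev.htod_sync_copy(&[1.0f32, 2.0, 3.0])?;
    // SAFETY: the kernel writes all 3 elements before we read them back.
    let y = unsafe { dev.alloc::<f32>(3) }?;
    let cfg = LaunchConfig::for_num_elems(3);
    // SAFETY: the parameter tuple matches the kernel signature.
    unsafe { f.launch(cfg, (3usize, &x, &y, 2.0f32)) }?;
    assert_eq!(dev.dtoh_sync_copy(&y)?, vec![2.0, 4.0, 6.0]);
    Ok(())
}

`for_num_elems` picks a 1-D grid large enough to cover `numel` threads, which is why the kernel only needs its bounds check.
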
diff --git a/src/device.rs b/src/device.rs
index c76cc301..e522cd42 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,11 +1,19 @@
 use crate::{CpuStorage, DType, Result, Shape, Storage};
 
+/// A `DeviceLocation` represents a physical device, whereas multiple `Device`
+/// handles can live on the same location (typically for cuda devices).
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-pub enum Device {
+pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
 }
 
+#[derive(Debug, Clone)]
+pub enum Device {
+    Cpu,
+    Cuda(crate::CudaDevice),
+}
+
 // TODO: Should we back the cpu implementation using the NdArray crate or similar?
 pub trait NdArray {
     fn shape(&self) -> Result<Shape>;
@@ -54,14 +62,31 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }
 
 impl Device {
+    pub fn new_cuda(ordinal: usize) -> Result<Self> {
+        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
+    }
+
+    pub fn location(&self) -> DeviceLocation {
+        match self {
+            Self::Cpu => DeviceLocation::Cpu,
+            Self::Cuda(device) => DeviceLocation::Cuda {
+                gpu_id: device.ordinal(),
+            },
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                // TODO: Instead of allocating memory on the host and transferring it,
+                // allocate some zeros on the device and use a kernel to set them to 1.
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
@@ -69,23 +94,23 @@ impl Device {
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::zeros_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
 
     pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
-            Device::Cpu => {
-                let storage = Storage::Cpu(array.to_cpu_storage());
-                Ok(storage)
-            }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
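
The split between `Device` and `DeviceLocation` keeps the live driver handle (an `Arc` in the CUDA case) out of equality checks: `Device` can no longer derive `PartialEq`, so device-mismatch tests compare locations instead. A hypothetical helper showing the intended comparison (not part of this commit):

use candle::Device;

// Two handles are on the same device iff their locations match,
// i.e. both CPU, or both CUDA with the same ordinal.
fn same_location(lhs: &Device, rhs: &Device) -> bool {
    lhs.location() == rhs.location()
}
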
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
new file mode 100644
index 00000000..f555327f
--- /dev/null
+++ b/src/dummy_cuda_backend.rs
@@ -0,0 +1,52 @@
+#![allow(dead_code)]
+use crate::{CpuStorage, DType, Result, Shape};
+
+pub type CudaError = std::io::Error;
+
+#[derive(Debug, Clone)]
+pub struct CudaDevice;
+
+macro_rules! fail {
+ () => {
+ unimplemented!("cuda support has not been enabled")
+ };
+}
+
+impl CudaDevice {
+ pub(crate) fn new(_: usize) -> Result<Self> {
+ fail!()
+ }
+
+ pub(crate) fn ordinal(&self) -> usize {
+ fail!()
+ }
+
+ pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+ fail!()
+ }
+
+ pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
+ fail!()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct CudaStorage;
+
+impl CudaStorage {
+ pub fn dtype(&self) -> DType {
+ fail!()
+ }
+
+ pub fn device(&self) -> CudaDevice {
+ fail!()
+ }
+
+ pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ fail!()
+ }
+
+ pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
+ fail!()
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index 0114a86c..3f142960 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{DType, Device, Shape};
+use crate::{DType, DeviceLocation, Shape};
 
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
@@ -15,8 +15,8 @@ pub enum Error {
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
-        lhs: Device,
-        rhs: Device,
+        lhs: DeviceLocation,
+        rhs: DeviceLocation,
         op: &'static str,
     },
@@ -33,6 +33,9 @@ pub enum Error {
         got: usize,
         shape: Shape,
     },
+
+    #[error(transparent)]
+    Cuda(#[from] crate::CudaError),
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
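
The new `Cuda` variant combines `#[error(transparent)]` with `#[from]`, so driver errors convert through `?` with no manual mapping and display as the underlying error. A minimal thiserror sketch of the same pattern, with `std::io::Error` standing in for the driver error type:

use thiserror::Error;

#[derive(Error, Debug)]
pub enum MyError {
    #[error(transparent)]
    Io(#[from] std::io::Error),
}

// `?` applies the generated `From<std::io::Error> for MyError` impl,
// and `transparent` forwards `Display`/`source` to the inner error.
fn read_config() -> Result<String, MyError> {
    Ok(std::fs::read_to_string("Cargo.toml")?)
}
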
diff --git a/src/lib.rs b/src/lib.rs
index 175d36ad..3bae1a7e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,9 @@
 mod cpu_backend;
+#[cfg(feature = "cuda")]
+mod cuda_backend;
 mod device;
 mod dtype;
+mod dummy_cuda_backend;
 mod error;
 mod op;
 mod shape;
@@ -9,10 +12,16 @@ mod strided_index;
 mod tensor;
 
 pub use cpu_backend::CpuStorage;
-pub use device::Device;
+pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
 pub use shape::Shape;
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
+
+#[cfg(feature = "cuda")]
+pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+
+#[cfg(not(feature = "cuda"))]
+pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
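
Both backends export identical names (`CudaDevice`, `CudaError`, `CudaStorage`), so the rest of the crate compiles unchanged whether or not the `cuda` feature is enabled; the dummy backend merely panics if reached at runtime. The pattern in miniature (a generic sketch, not the crate's code):

#[cfg(feature = "cuda")]
mod backend {
    pub fn name() -> &'static str {
        "cuda"
    }
}

#[cfg(not(feature = "cuda"))]
mod backend {
    // Same signature as the real module; the stub keeps callers
    // compiling and only fails if actually invoked.
    pub fn name() -> &'static str {
        unimplemented!("cuda support has not been enabled")
    }
}

pub use backend::name;
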
diff --git a/src/shape.rs b/src/shape.rs
index d626aee6..ebc497cf 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -128,6 +128,20 @@ impl Shape {
         stride.reverse();
         stride
     }
+
+    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
+        if self.0.len() != stride.len() {
+            return false;
+        }
+        let mut acc = 1;
+        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
+            if stride != acc {
+                return false;
+            }
+            acc *= dim;
+        }
+        true
+    }
 }
 
 #[cfg(test)]
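
`is_contiguous` walks the dimensions from innermost to outermost, requiring each stride to equal the product of all trailing dimensions. A worked example as a test (assuming `Shape: From<(usize, usize)>` as used elsewhere in the crate):

#[test]
fn contiguous_strides() {
    let shape = Shape::from((2, 3));
    // Row-major (2, 3): innermost stride 1, then 1 * 3 = 3.
    assert!(shape.is_contiguous(&[3, 1]));
    // A transposed view swaps the strides without moving data: not contiguous.
    assert!(!shape.is_contiguous(&[1, 3]));
    // Rank mismatch is rejected outright.
    assert!(!shape.is_contiguous(&[1]));
}
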
diff --git a/src/storage.rs b/src/storage.rs
index 7083cc28..573cf945 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,9 +1,9 @@
-use crate::{CpuStorage, DType, Device, Error, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
 
 #[derive(Debug, Clone)]
 pub enum Storage {
     Cpu(CpuStorage),
-    Cuda { gpu_id: usize }, // TODO: Actually add the storage.
+    Cuda(CudaStorage),
 }
 
 pub(crate) trait UnaryOp {
@@ -100,20 +100,20 @@ impl Storage {
     pub fn device(&self) -> Device {
         match self {
             Self::Cpu(_) => Device::Cpu,
-            Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
+            Self::Cuda(storage) => Device::Cuda(storage.device()),
         }
     }
 
     pub fn dtype(&self) -> DType {
         match self {
             Self::Cpu(storage) => storage.dtype(),
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => storage.dtype(),
         }
     }
 
     pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
-        let lhs = self.device();
-        let rhs = rhs.device();
+        let lhs = self.device().location();
+        let rhs = rhs.device().location();
         if lhs != rhs {
             Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
         } else {
@@ -144,7 +144,10 @@ impl Storage {
                 let storage = storage.affine_impl(shape, stride, mul, add)?;
                 Ok(Self::Cpu(storage))
             }
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => {
+                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                Ok(Self::Cuda(storage))
+            }
         }
     }
@@ -179,8 +182,8 @@ impl Storage {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
                 Err(Error::DeviceMismatchBinaryOp {
-                    lhs: lhs.device(),
-                    rhs: rhs.device(),
+                    lhs: lhs.device().location(),
+                    rhs: rhs.device().location(),
                     op: B::NAME,
                 })
             }
diff --git a/src/tensor.rs b/src/tensor.rs
index 9ba412f9..02105573 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -84,7 +84,7 @@ impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -101,22 +101,22 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, true)
     }
 
     pub fn ones_like(&self) -> Result<Self> {
-        Tensor::ones(self.shape(), self.dtype(), self.device())
+        Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -133,21 +133,21 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
     pub fn zeros_like(&self) -> Result<Self> {
-        Tensor::zeros(self.shape(), self.dtype(), self.device())
+        Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
 
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = array.shape()?;
@@ -164,11 +164,11 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, false)
     }
 
-    pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, true)
     }
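
Constructors now borrow the device rather than consuming it: `Device` holds an `Arc`-backed handle in the CUDA case, so one handle can serve many allocations by reference. A short sketch under the post-merge signatures:

use candle::{DType, Device, Result, Tensor};

fn build_pair(device: &Device) -> Result<(Tensor, Tensor)> {
    // The same borrowed handle backs both tensors; no clone needed.
    let a = Tensor::zeros((4, 4), DType::F32, device)?;
    let b = Tensor::ones((4, 4), DType::F32, device)?;
    Ok((a, b))
}
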
@@ -250,7 +250,12 @@ impl Tensor {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
             }
-            Storage::Cuda { .. } => todo!(),
+            Storage::Cuda(slice) => {
+                // TODO: Would it be possible to only fetch the necessary data?
+                let cpu_storage = slice.to_cpu_storage()?;
+                let data = S::cpu_storage_as_slice(&cpu_storage)?;
+                Ok(self.strided_index().map(|i| data[i]).collect())
+            }
         }
     }
@@ -305,14 +310,7 @@ impl Tensor {
     }
 
     pub fn is_contiguous(&self) -> bool {
-        let mut acc = 1;
-        for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
-            if stride != acc {
-                return false;
-            }
-            acc *= dim;
-        }
-        true
+        self.shape.is_contiguous(&self.stride)
     }
 
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first