Diffstat (limited to 'src/device.rs')
-rw-r--r-- | src/device.rs | 55 |
1 file changed, 40 insertions, 15 deletions
diff --git a/src/device.rs b/src/device.rs
index c76cc301..e522cd42 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,11 +1,19 @@
 use crate::{CpuStorage, DType, Result, Shape, Storage};

+/// A `DeviceLocation` represents a physical device whereas multiple `Device`
+/// can live on the same location (typically for cuda devices).
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-pub enum Device {
+pub enum DeviceLocation {
     Cpu,
     Cuda { gpu_id: usize },
 }

+#[derive(Debug, Clone)]
+pub enum Device {
+    Cpu,
+    Cuda(crate::CudaDevice),
+}
+
 // TODO: Should we back the cpu implementation using the NdArray crate or similar?
 pub trait NdArray {
     fn shape(&self) -> Result<Shape>;
@@ -54,14 +62,31 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }

 impl Device {
+    pub fn new_cuda(ordinal: usize) -> Result<Self> {
+        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
+    }
+
+    pub fn location(&self) -> DeviceLocation {
+        match self {
+            Self::Cpu => DeviceLocation::Cpu,
+            Self::Cuda(device) => DeviceLocation::Cuda {
+                gpu_id: device.ordinal(),
+            },
+        }
+    }
+
     pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                // TODO: Instead of allocating memory on the host and transfering it,
+                // allocate some zeros on the device and use a shader to set them to 1.
+                let storage = CpuStorage::ones_impl(shape, dtype);
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
@@ -69,23 +94,23 @@ impl Device {
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
-                let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
-                Ok(storage)
+                let storage = CpuStorage::zeros_impl(shape, dtype);
+                Ok(Storage::Cpu(storage))
             }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cuda(device) => {
+                let storage = device.zeros_impl(shape, dtype)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }

     pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
-            Device::Cpu => {
-                let storage = Storage::Cpu(array.to_cpu_storage());
-                Ok(storage)
-            }
-            Device::Cuda { gpu_id: _ } => {
-                todo!()
+            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
+            Device::Cuda(device) => {
+                let storage = array.to_cpu_storage();
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
             }
         }
     }
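For reference, a minimal sketch of how the public API introduced by this change might be used. It assumes the crate is named `candle` and that `Device`, `DeviceLocation`, and `Result` are re-exported at the crate root, neither of which is shown in this diff; treat it as an illustration rather than documented usage.

    use candle::{Device, DeviceLocation, Result};

    fn main() -> Result<()> {
        // `new_cuda` returns a Result because initializing the CUDA device can fail;
        // fall back to the CPU device in that case.
        let device = Device::new_cuda(0).unwrap_or(Device::Cpu);

        // `location()` maps the logical `Device` back to its physical `DeviceLocation`,
        // which stays Copy/Eq/Hash even though `Device` itself no longer is.
        match device.location() {
            DeviceLocation::Cpu => println!("running on the cpu"),
            DeviceLocation::Cuda { gpu_id } => println!("running on cuda gpu #{gpu_id}"),
        }
        Ok(())
    }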