diff options
Diffstat (limited to 'src/storage.rs')
-rw-r--r-- | src/storage.rs | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/src/storage.rs b/src/storage.rs index 7083cc28..573cf945 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,9 +1,9 @@ -use crate::{CpuStorage, DType, Device, Error, Result, Shape}; +use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape}; #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), - Cuda { gpu_id: usize }, // TODO: Actually add the storage. + Cuda(CudaStorage), } pub(crate) trait UnaryOp { @@ -100,20 +100,20 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, - Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, + Self::Cuda(storage) => Device::Cuda(storage.device()), } } pub fn dtype(&self) -> DType { match self { Self::Cpu(storage) => storage.dtype(), - Self::Cuda { .. } => todo!(), + Self::Cuda(storage) => storage.dtype(), } } pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { - let lhs = self.device(); - let rhs = rhs.device(); + let lhs = self.device().location(); + let rhs = rhs.device().location(); if lhs != rhs { Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }) } else { @@ -144,7 +144,10 @@ impl Storage { let storage = storage.affine_impl(shape, stride, mul, add)?; Ok(Self::Cpu(storage)) } - Self::Cuda { .. } => todo!(), + Self::Cuda(storage) => { + let storage = storage.affine_impl(shape, stride, mul, add)?; + Ok(Self::Cuda(storage)) + } } } @@ -179,8 +182,8 @@ impl Storage { // Should not happen because of the same device check above but we're defensive // anyway. Err(Error::DeviceMismatchBinaryOp { - lhs: lhs.device(), - rhs: rhs.device(), + lhs: lhs.device().location(), + rhs: rhs.device().location(), op: B::NAME, }) } |