diff options
Diffstat (limited to 'src/storage.rs')
-rw-r--r-- | src/storage.rs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/storage.rs b/src/storage.rs index 463788d4..30161a2c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -20,6 +20,7 @@ impl CpuStorage { #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), + Cuda { gpu_id: usize }, // TODO: Actually add the storage. } trait UnaryOp { @@ -116,12 +117,14 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, + Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, } } pub fn dtype(&self) -> DType { match self { Self::Cpu(storage) => storage.dtype(), + Self::Cuda { .. } => todo!(), } } @@ -168,6 +171,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -186,6 +190,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -232,6 +237,16 @@ impl Storage { }) } }, + (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(), + (lhs, rhs) => { + // Should not happen because of the same device check above but we're defensive + // anyway. + Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device(), + rhs: rhs.device(), + op: B::NAME, + }) + } } } |