diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-21 09:13:57 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-21 09:13:57 +0100 |
commit | 8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c (patch) | |
tree | 265aa584ac6cb9024716e08e38b61fe454431823 /src/storage.rs | |
parent | f319583530745dfab125bd2d16c2dfa4aa75646d (diff) | |
download | candle-8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c.tar.gz candle-8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c.tar.bz2 candle-8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c.zip |
Add some skeleton code for GPU support.
Diffstat (limited to 'src/storage.rs')
-rw-r--r-- | src/storage.rs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/src/storage.rs b/src/storage.rs index 463788d4..30161a2c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -20,6 +20,7 @@ impl CpuStorage { #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), + Cuda { gpu_id: usize }, // TODO: Actually add the storage. } trait UnaryOp { @@ -116,12 +117,14 @@ impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, + Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, } } pub fn dtype(&self) -> DType { match self { Self::Cpu(storage) => storage.dtype(), + Self::Cuda { .. } => todo!(), } } @@ -168,6 +171,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -186,6 +190,7 @@ impl Storage { Ok(Storage::Cpu(CpuStorage::F64(data))) } }, + Self::Cuda { .. } => todo!(), } } @@ -232,6 +237,16 @@ impl Storage { }) } }, + (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(), + (lhs, rhs) => { + // Should not happen because of the same device check above but we're defensive + // anyway. + Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device(), + rhs: rhs.device(), + op: B::NAME, + }) + } } } |