summaryrefslogtreecommitdiff
path: root/src/storage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/storage.rs')
-rw-r--r--src/storage.rs15
1 files changed, 15 insertions, 0 deletions
diff --git a/src/storage.rs b/src/storage.rs
index 463788d4..30161a2c 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -20,6 +20,7 @@ impl CpuStorage {
#[derive(Debug, Clone)]
pub enum Storage {
Cpu(CpuStorage),
+ Cuda { gpu_id: usize }, // TODO: Actually add the storage.
}
trait UnaryOp {
@@ -116,12 +117,14 @@ impl Storage {
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
+ Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
}
}
pub fn dtype(&self) -> DType {
match self {
Self::Cpu(storage) => storage.dtype(),
+ Self::Cuda { .. } => todo!(),
}
}
@@ -168,6 +171,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data)))
}
},
+ Self::Cuda { .. } => todo!(),
}
}
@@ -186,6 +190,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data)))
}
},
+ Self::Cuda { .. } => todo!(),
}
}
@@ -232,6 +237,16 @@ impl Storage {
})
}
},
+ (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
+ (lhs, rhs) => {
+ // Should not happen because of the same device check above but we're defensive
+ // anyway.
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device(),
+ rhs: rhs.device(),
+ op: B::NAME,
+ })
+ }
}
}