summaryrefslogtreecommitdiff
path: root/src/storage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/storage.rs')
-rw-r--r--src/storage.rs21
1 files changed, 12 insertions, 9 deletions
diff --git a/src/storage.rs b/src/storage.rs
index 7083cc28..573cf945 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,9 +1,9 @@
-use crate::{CpuStorage, DType, Device, Error, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)]
pub enum Storage {
Cpu(CpuStorage),
- Cuda { gpu_id: usize }, // TODO: Actually add the storage.
+ Cuda(CudaStorage),
}
pub(crate) trait UnaryOp {
@@ -100,20 +100,20 @@ impl Storage {
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
- Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
+ Self::Cuda(storage) => Device::Cuda(storage.device()),
}
}
pub fn dtype(&self) -> DType {
match self {
Self::Cpu(storage) => storage.dtype(),
- Self::Cuda { .. } => todo!(),
+ Self::Cuda(storage) => storage.dtype(),
}
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
- let lhs = self.device();
- let rhs = rhs.device();
+ let lhs = self.device().location();
+ let rhs = rhs.device().location();
if lhs != rhs {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
} else {
@@ -144,7 +144,10 @@ impl Storage {
let storage = storage.affine_impl(shape, stride, mul, add)?;
Ok(Self::Cpu(storage))
}
- Self::Cuda { .. } => todo!(),
+ Self::Cuda(storage) => {
+ let storage = storage.affine_impl(shape, stride, mul, add)?;
+ Ok(Self::Cuda(storage))
+ }
}
}
@@ -179,8 +182,8 @@ impl Storage {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
- lhs: lhs.device(),
- rhs: rhs.device(),
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
op: B::NAME,
})
}