summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/device.rs10
-rw-r--r--src/storage.rs15
-rw-r--r--src/tensor.rs3
3 files changed, 28 insertions, 0 deletions
diff --git a/src/device.rs b/src/device.rs
index d7b724d1..c092a347 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -6,6 +6,7 @@ use crate::{
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
Cpu,
+ Cuda { gpu_id: usize },
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
@@ -72,6 +73,9 @@ impl Device {
};
Storage::Cpu(storage)
}
+ Device::Cuda { gpu_id: _ } => {
+ todo!()
+ }
}
}
@@ -91,12 +95,18 @@ impl Device {
};
Storage::Cpu(storage)
}
+ Device::Cuda { gpu_id: _ } => {
+ todo!()
+ }
}
}
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
match self {
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
+ Device::Cuda { gpu_id: _ } => {
+ todo!()
+ }
}
}
}
diff --git a/src/storage.rs b/src/storage.rs
index 463788d4..30161a2c 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -20,6 +20,7 @@ impl CpuStorage {
#[derive(Debug, Clone)]
pub enum Storage {
Cpu(CpuStorage),
+ Cuda { gpu_id: usize }, // TODO: Actually add the storage.
}
trait UnaryOp {
@@ -116,12 +117,14 @@ impl Storage {
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
+ Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
}
}
pub fn dtype(&self) -> DType {
match self {
Self::Cpu(storage) => storage.dtype(),
+ Self::Cuda { .. } => todo!(),
}
}
@@ -168,6 +171,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data)))
}
},
+ Self::Cuda { .. } => todo!(),
}
}
@@ -186,6 +190,7 @@ impl Storage {
Ok(Storage::Cpu(CpuStorage::F64(data)))
}
},
+ Self::Cuda { .. } => todo!(),
}
}
@@ -232,6 +237,16 @@ impl Storage {
})
}
},
+ (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
+ (lhs, rhs) => {
+ // Should not happen because of the same device check above but we're defensive
+ // anyway.
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device(),
+ rhs: rhs.device(),
+ op: B::NAME,
+ })
+ }
}
}
diff --git a/src/tensor.rs b/src/tensor.rs
index b8fa738a..2d704a65 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -209,6 +209,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(data[0])
}
+ Storage::Cuda { .. } => todo!(),
}
}
@@ -249,6 +250,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect())
}
+ Storage::Cuda { .. } => todo!(),
}
}
@@ -266,6 +268,7 @@ impl Tensor {
assert!(src_index.next().is_none());
Ok(rows)
}
+ Storage::Cuda { .. } => todo!(),
}
}