summaryrefslogtreecommitdiff
path: root/src/storage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/storage.rs')
-rw-r--r--src/storage.rs48
1 files changed, 47 insertions, 1 deletions
diff --git a/src/storage.rs b/src/storage.rs
index f1a2d5a0..22f9a26c 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -16,6 +16,13 @@ pub(crate) trait BinaryOp {
const NAME: &'static str;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
+ fn cuda_impl(
+ lhs: &CudaStorage,
+ rhs: &CudaStorage,
+ shape: &Shape,
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<CudaStorage>;
}
struct Add;
@@ -34,6 +41,15 @@ impl BinaryOp for Add {
fn f64(v1: f64, v2: f64) -> f64 {
v1 + v2
}
+ fn cuda_impl(
+ lhs: &CudaStorage,
+ rhs: &CudaStorage,
+ shape: &Shape,
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<CudaStorage> {
+ Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
+ }
}
impl BinaryOp for Sub {
@@ -44,6 +60,15 @@ impl BinaryOp for Sub {
fn f64(v1: f64, v2: f64) -> f64 {
v1 - v2
}
+ fn cuda_impl(
+ _: &CudaStorage,
+ _: &CudaStorage,
+ _: &Shape,
+ _: &[usize],
+ _: &[usize],
+ ) -> Result<CudaStorage> {
+ todo!()
+ }
}
impl BinaryOp for Mul {
@@ -54,6 +79,15 @@ impl BinaryOp for Mul {
fn f64(v1: f64, v2: f64) -> f64 {
v1 * v2
}
+ fn cuda_impl(
+ _: &CudaStorage,
+ _: &CudaStorage,
+ _: &Shape,
+ _: &[usize],
+ _: &[usize],
+ ) -> Result<CudaStorage> {
+ todo!()
+ }
}
impl BinaryOp for Div {
@@ -64,6 +98,15 @@ impl BinaryOp for Div {
fn f64(v1: f64, v2: f64) -> f64 {
v1 / v2
}
+ fn cuda_impl(
+ _: &CudaStorage,
+ _: &CudaStorage,
+ _: &Shape,
+ _: &[usize],
+ _: &[usize],
+ ) -> Result<CudaStorage> {
+ todo!()
+ }
}
impl UnaryOp for Neg {
@@ -177,7 +220,10 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
- (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
+ (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+ let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
+ Ok(Self::Cuda(storage))
+ }
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.