use crate::{CpuStorage, DType, Device, Error, Result, Shape}; #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), Cuda { gpu_id: usize }, // TODO: Actually add the storage. } pub(crate) trait UnaryOp { const NAME: &'static str; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; } pub(crate) trait BinaryOp { const NAME: &'static str; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; } struct Add; struct Div; struct Mul; struct Sub; struct Neg; struct Sqr; struct Sqrt; impl BinaryOp for Add { const NAME: &'static str = "add"; fn f32(v1: f32, v2: f32) -> f32 { v1 + v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 + v2 } } impl BinaryOp for Sub { const NAME: &'static str = "sub"; fn f32(v1: f32, v2: f32) -> f32 { v1 - v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 - v2 } } impl BinaryOp for Mul { const NAME: &'static str = "mul"; fn f32(v1: f32, v2: f32) -> f32 { v1 * v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 * v2 } } impl BinaryOp for Div { const NAME: &'static str = "div"; fn f32(v1: f32, v2: f32) -> f32 { v1 / v2 } fn f64(v1: f64, v2: f64) -> f64 { v1 / v2 } } impl UnaryOp for Neg { const NAME: &'static str = "neg"; fn f32(v1: f32) -> f32 { -v1 } fn f64(v1: f64) -> f64 { -v1 } } impl UnaryOp for Sqr { const NAME: &'static str = "sqr"; fn f32(v1: f32) -> f32 { v1 * v1 } fn f64(v1: f64) -> f64 { v1 * v1 } } impl UnaryOp for Sqrt { const NAME: &'static str = "sqrt"; fn f32(v1: f32) -> f32 { v1.sqrt() } fn f64(v1: f64) -> f64 { v1.sqrt() } } impl Storage { pub fn device(&self) -> Device { match self { Self::Cpu(_) => Device::Cpu, Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id }, } } pub fn dtype(&self) -> DType { match self { Self::Cpu(storage) => storage.dtype(), Self::Cuda { .. } => todo!(), } } pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { let lhs = self.device(); let rhs = rhs.device(); if lhs != rhs { Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }) } else { Ok(()) } } pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> { let lhs = self.dtype(); let rhs = rhs.dtype(); if lhs != rhs { Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }) } else { Ok(()) } } pub(crate) fn affine_impl( &self, shape: &Shape, stride: &[usize], mul: f64, add: f64, ) -> Result { // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => { let storage = storage.affine_impl(shape, stride, mul, add)?; Ok(Self::Cpu(storage)) } Self::Cuda { .. } => todo!(), } } fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result { // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => { let storage = storage.unary_impl::(shape, stride)?; Ok(Self::Cpu(storage)) } Self::Cuda { .. } => todo!(), } } // TODO: Support broadcasting? fn binary_impl( &self, rhs: &Self, shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { self.same_device(rhs, B::NAME)?; self.same_dtype(rhs, B::NAME)?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; Ok(Self::Cpu(storage)) } (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(), (lhs, rhs) => { // Should not happen because of the same device check above but we're defensive // anyway. Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device(), rhs: rhs.device(), op: B::NAME, }) } } } pub(crate) fn add_impl( &self, rhs: &Self, shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { self.binary_impl::(rhs, shape, lhs_stride, rhs_stride) } pub(crate) fn sub_impl( &self, rhs: &Self, shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { self.binary_impl::(rhs, shape, lhs_stride, rhs_stride) } pub(crate) fn mul_impl( &self, rhs: &Self, shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { self.binary_impl::(rhs, shape, lhs_stride, rhs_stride) } pub(crate) fn div_impl( &self, rhs: &Self, shape: &Shape, lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { self.binary_impl::
(rhs, shape, lhs_stride, rhs_stride) } pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result { self.unary_impl::(shape, stride) } pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result { self.unary_impl::(shape, stride) } pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result { self.unary_impl::(shape, stride) } }