diff options
Diffstat (limited to 'candle-core/src/backend.rs')
-rw-r--r-- | candle-core/src/backend.rs | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs new file mode 100644 index 00000000..aa35703d --- /dev/null +++ b/candle-core/src/backend.rs @@ -0,0 +1,71 @@ +use crate::{CpuStorage, DType, Layout, Result, Shape}; + +pub(crate) trait BackendStorage: Sized { + type Device: BackendDevice; + + fn try_clone(&self, _: &Layout) -> Result<Self>; + + fn dtype(&self) -> DType; + + fn device(&self) -> &Self::Device; + + fn to_cpu_storage(&self) -> Result<CpuStorage>; + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>; + + fn elu(&self, _: &Layout, _: f64) -> Result<Self>; + + fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>; + + fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>; + + fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>; + + fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>; + + fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout) + -> Result<Self>; + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>; + + fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv1D, + ) -> Result<Self>; + + fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>; + + fn matmul( + &self, + _: &Self, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result<Self>; + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; +} + +pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone { + type Storage: BackendStorage; + + // TODO: Make the usize generic and part of a generic DeviceLocation. + fn new(_: usize) -> Result<Self>; + + fn location(&self) -> crate::DeviceLocation; + + fn same_device(&self, _: &Self) -> bool; + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>; + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>; + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>; + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; +} |