diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-21 10:08:41 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-21 10:08:41 +0100 |
commit | b3eb57cd0a696ec184e47c7316871b01e0a45aea (patch) | |
tree | da25dd7a4a6675841aa2c91cf8a6267a55f68722 /src | |
parent | 8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c (diff) | |
download | candle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.tar.gz candle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.tar.bz2 candle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.zip |
Avoid some duplication using a macro + add some basic example to make debugging easier.
Diffstat (limited to 'src')
-rw-r--r-- | src/dtype.rs | 54 |
1 files changed, 21 insertions, 33 deletions
diff --git a/src/dtype.rs b/src/dtype.rs index d66d046c..fd0eaa1b 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -27,38 +27,26 @@ pub trait WithDType: Sized + Copy { fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; } -impl WithDType for f32 { - const DTYPE: DType = DType::F32; - - fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { - CpuStorage::F32(data) - } - - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F32(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F32, - got: s.dtype(), - }), +macro_rules! with_dtype { + ($ty:ty, $dtype:ident) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; + + fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { + CpuStorage::$dtype(data) + } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + }), + } + } } - } -} - -impl WithDType for f64 { - const DTYPE: DType = DType::F64; - - fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { - CpuStorage::F64(data) - } - - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F64(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F64, - got: s.dtype(), - }), - } - } + }; } +with_dtype!(f32, F32); +with_dtype!(f64, F64); |