summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-21 10:08:41 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-21 10:08:41 +0100
commitb3eb57cd0a696ec184e47c7316871b01e0a45aea (patch)
treeda25dd7a4a6675841aa2c91cf8a6267a55f68722 /src
parent8cde0c54788d7ae7c676e4f2fad5fcbc16f6980c (diff)
downloadcandle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.tar.gz
candle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.tar.bz2
candle-b3eb57cd0a696ec184e47c7316871b01e0a45aea.zip
Avoid some duplication using a macro + add some basic example to make debugging easier.
Diffstat (limited to 'src')
-rw-r--r--src/dtype.rs54
1 files changed, 21 insertions, 33 deletions
diff --git a/src/dtype.rs b/src/dtype.rs
index d66d046c..fd0eaa1b 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -27,38 +27,26 @@ pub trait WithDType: Sized + Copy {
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
}
-impl WithDType for f32 {
- const DTYPE: DType = DType::F32;
-
- fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
- CpuStorage::F32(data)
- }
-
- fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
- match s {
- CpuStorage::F32(data) => Ok(data),
- _ => Err(Error::UnexpectedDType {
- expected: DType::F32,
- got: s.dtype(),
- }),
+macro_rules! with_dtype {
+ ($ty:ty, $dtype:ident) => {
+ impl WithDType for $ty {
+ const DTYPE: DType = DType::$dtype;
+
+ fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
+ CpuStorage::$dtype(data)
+ }
+
+ fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
+ match s {
+ CpuStorage::$dtype(data) => Ok(data),
+ _ => Err(Error::UnexpectedDType {
+ expected: DType::$dtype,
+ got: s.dtype(),
+ }),
+ }
+ }
}
- }
-}
-
-impl WithDType for f64 {
- const DTYPE: DType = DType::F64;
-
- fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
- CpuStorage::F64(data)
- }
-
- fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
- match s {
- CpuStorage::F64(data) => Ok(data),
- _ => Err(Error::UnexpectedDType {
- expected: DType::F64,
- got: s.dtype(),
- }),
- }
- }
+ };
}
+with_dtype!(f32, F32);
+with_dtype!(f64, F64);