diff options
-rw-r--r-- | examples/basics.rs | 9 | ||||
-rw-r--r-- | src/dtype.rs | 54 |
2 files changed, 30 insertions, 33 deletions
diff --git a/examples/basics.rs b/examples/basics.rs new file mode 100644 index 00000000..f01f7871 --- /dev/null +++ b/examples/basics.rs @@ -0,0 +1,9 @@ +use anyhow::Result; +use candle::{Device, Tensor}; + +fn main() -> Result<()> { + let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?; + let y = (((&x * &x)? + &x * 5f64)? + 4f64)?; + println!("{:?}", y.to_vec1::<f32>()?); + Ok(()) +} diff --git a/src/dtype.rs b/src/dtype.rs index d66d046c..fd0eaa1b 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -27,38 +27,26 @@ pub trait WithDType: Sized + Copy { fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; } -impl WithDType for f32 { - const DTYPE: DType = DType::F32; - - fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { - CpuStorage::F32(data) - } - - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F32(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F32, - got: s.dtype(), - }), +macro_rules! with_dtype { + ($ty:ty, $dtype:ident) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; + + fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { + CpuStorage::$dtype(data) + } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + }), + } + } } - } -} - -impl WithDType for f64 { - const DTYPE: DType = DType::F64; - - fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage { - CpuStorage::F64(data) - } - - fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { - match s { - CpuStorage::F64(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::F64, - got: s.dtype(), - }), - } - } + }; } +with_dtype!(f32, F32); +with_dtype!(f64, F64); |