summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/basics.rs9
-rw-r--r--src/dtype.rs54
2 files changed, 30 insertions, 33 deletions
diff --git a/examples/basics.rs b/examples/basics.rs
new file mode 100644
index 00000000..f01f7871
--- /dev/null
+++ b/examples/basics.rs
@@ -0,0 +1,9 @@
+use anyhow::Result;
+use candle::{Device, Tensor};
+
+fn main() -> Result<()> {
+ let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
+ let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
+ println!("{:?}", y.to_vec1::<f32>()?);
+ Ok(())
+}
diff --git a/src/dtype.rs b/src/dtype.rs
index d66d046c..fd0eaa1b 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -27,38 +27,26 @@ pub trait WithDType: Sized + Copy {
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
}
-impl WithDType for f32 {
- const DTYPE: DType = DType::F32;
-
- fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
- CpuStorage::F32(data)
- }
-
- fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
- match s {
- CpuStorage::F32(data) => Ok(data),
- _ => Err(Error::UnexpectedDType {
- expected: DType::F32,
- got: s.dtype(),
- }),
+macro_rules! with_dtype {
+ ($ty:ty, $dtype:ident) => {
+ impl WithDType for $ty {
+ const DTYPE: DType = DType::$dtype;
+
+ fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
+ CpuStorage::$dtype(data)
+ }
+
+ fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
+ match s {
+ CpuStorage::$dtype(data) => Ok(data),
+ _ => Err(Error::UnexpectedDType {
+ expected: DType::$dtype,
+ got: s.dtype(),
+ }),
+ }
+ }
}
- }
-}
-
-impl WithDType for f64 {
- const DTYPE: DType = DType::F64;
-
- fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
- CpuStorage::F64(data)
- }
-
- fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
- match s {
- CpuStorage::F64(data) => Ok(data),
- _ => Err(Error::UnexpectedDType {
- expected: DType::F64,
- got: s.dtype(),
- }),
- }
- }
+ };
}
+with_dtype!(f32, F32);
+with_dtype!(f64, F64);