summaryrefslogtreecommitdiff
path: root/candle-core/examples/basics.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/examples/basics.rs')
-rw-r--r--candle-core/examples/basics.rs29
1 files changed, 9 insertions, 20 deletions
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index d028db66..efce913a 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -1,29 +1,18 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
- let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
- let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
- let c = a.matmul(&b)?;
- println!("{a} {b} {c}");
-
- let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
- let t1 = Tensor::new(data, &Device::Cpu)?;
- let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
- let t2 = Tensor::new(data2, &Device::Cpu)?;
- assert_eq!(
- Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
- .t()?
- .to_vec2::<f32>()?,
- [
- [3.0, 1.0, 4.0, 1.0, 5.0],
- [2.0, 7.0, 1.0, 8.0, 2.0],
- [5.0, 5.0, 5.0, 5.0, 5.0],
- [2.0, 7.0, 1.0, 8.0, 2.0]
- ]
- );
+ let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
+ let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
+ let start = std::time::Instant::now();
+ let res = inp.conv2d(&w, 0, 1);
+ println!("{:?}", start.elapsed());
+ println!("{res:?}");
Ok(())
}