summaryrefslogtreecommitdiff
path: root/candle-core/benches/benchmarks/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/benches/benchmarks/mod.rs')
-rw-r--r--candle-core/benches/benchmarks/mod.rs55
1 files changed, 55 insertions, 0 deletions
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
new file mode 100644
index 00000000..1344770d
--- /dev/null
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -0,0 +1,55 @@
+pub(crate) mod matmul;
+
+use candle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+ fn sync(&self) -> Result<()>;
+}
+
+impl BenchDevice for Device {
+ fn sync(&self) -> Result<()> {
+ match self {
+ Device::Cpu => Ok(()),
+ Device::Cuda(device) => {
+ #[cfg(feature = "cuda")]
+ return Ok(device.synchronize()?);
+ #[cfg(not(feature = "cuda"))]
+ panic!("Cuda device without cuda feature enabled: {:?}", device)
+ }
+ Device::Metal(device) => {
+ #[cfg(feature = "metal")]
+ return Ok(device.wait_until_completed()?);
+ #[cfg(not(feature = "metal"))]
+ panic!("Metal device without metal feature enabled: {:?}", device)
+ }
+ }
+ }
+}
+
+pub(crate) fn device() -> Result<Device> {
+ if cfg!(feature = "metal") {
+ Device::new_metal(0)
+ } else if cfg!(feature = "cuda") {
+ Device::new_cuda(0)
+ } else {
+ Ok(Device::Cpu)
+ }
+}
+
+pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
+ format!("{}_{}", device_variant(), name.into())
+}
+
+const fn device_variant() -> &'static str {
+ if cfg!(feature = "metal") {
+ "metal"
+ } else if cfg!(feature = "cuda") {
+ "cuda"
+ } else if cfg!(feature = "accelerate") {
+ "accelerate"
+ } else if cfg!(feature = "mkl") {
+ "mkl"
+ } else {
+ "cpu"
+ }
+}