summaryrefslogtreecommitdiff
path: root/candle-core/benches/bench_utils.rs
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-07 14:40:15 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-07 14:40:15 +0100
commit3f04a79ada7ca974176a0c7c3c3306f394eae9a9 (patch)
tree74daca4c8acfe596d3039a3f042e7d9b380a249b /candle-core/benches/bench_utils.rs
parent84250bf52f58528cf59dca3b82effd9f07a13cc7 (diff)
downloadcandle-3f04a79ada7ca974176a0c7c3c3306f394eae9a9.tar.gz
candle-3f04a79ada7ca974176a0c7c3c3306f394eae9a9.tar.bz2
candle-3f04a79ada7ca974176a0c7c3c3306f394eae9a9.zip
Use cfg to seperate benchmark results based on features
Diffstat (limited to 'candle-core/benches/bench_utils.rs')
-rw-r--r--candle-core/benches/bench_utils.rs56
1 files changed, 56 insertions, 0 deletions
diff --git a/candle-core/benches/bench_utils.rs b/candle-core/benches/bench_utils.rs
new file mode 100644
index 00000000..75800761
--- /dev/null
+++ b/candle-core/benches/bench_utils.rs
@@ -0,0 +1,56 @@
+use candle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+ fn sync(&self) -> Result<()>;
+}
+
+impl BenchDevice for Device {
+ fn sync(&self) -> Result<()> {
+ match self {
+ Device::Cpu => Ok(()),
+ Device::Cuda(device) => {
+ #[cfg(feature = "cuda")]
+ return Ok(device.synchronize()?);
+ #[cfg(not(feature = "cuda"))]
+ panic!("Cuda device without cuda feature enabled: {:?}", device)
+ }
+ Device::Metal(device) => {
+ #[cfg(feature = "metal")]
+ return Ok(device.wait_until_completed()?);
+ #[cfg(not(feature = "metal"))]
+ panic!("Metal device without metal feature enabled: {:?}", device)
+ }
+ }
+ }
+}
+
+#[allow(dead_code)]
+pub(crate) fn device() -> Result<Device> {
+ return if cfg!(feature = "metal") {
+ Device::new_metal(0)
+ } else if cfg!(feature = "cuda") {
+ Device::new_cuda(0)
+ } else {
+ Ok(Device::Cpu)
+ };
+}
+
+#[allow(dead_code)]
+pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
+ format!("{}_{}", device_variant(), name.into())
+}
+
+#[allow(dead_code)]
+const fn device_variant() -> &'static str {
+ return if cfg!(feature = "metal") {
+ "metal"
+ } else if cfg!(feature = "cuda") {
+ "cuda"
+ } else if cfg!(feature = "accelerate") {
+ "accelerate"
+ } else if cfg!(feature = "mkl") {
+ "mkl"
+ } else {
+ "cpu"
+ };
+}