summaryrefslogtreecommitdiff
path: root/candle-nn/benches/benchmarks/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/benches/benchmarks/mod.rs')
-rw-r--r--candle-nn/benches/benchmarks/mod.rs64
1 files changed, 64 insertions, 0 deletions
diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs
new file mode 100644
index 00000000..30a6ab6a
--- /dev/null
+++ b/candle-nn/benches/benchmarks/mod.rs
@@ -0,0 +1,64 @@
+pub(crate) mod conv;
+pub(crate) mod layer_norm;
+
+use candle::{Device, Result};
+
+pub(crate) trait BenchDevice {
+ fn sync(&self) -> Result<()>;
+
+ fn bench_name<S: Into<String>>(&self, name: S) -> String;
+}
+
+impl BenchDevice for Device {
+ fn sync(&self) -> Result<()> {
+ match self {
+ Device::Cpu => Ok(()),
+ Device::Cuda(device) => {
+ #[cfg(feature = "cuda")]
+ return Ok(device.synchronize()?);
+ #[cfg(not(feature = "cuda"))]
+ panic!("Cuda device without cuda feature enabled: {:?}", device)
+ }
+ Device::Metal(device) => {
+ #[cfg(feature = "metal")]
+ return Ok(device.wait_until_completed()?);
+ #[cfg(not(feature = "metal"))]
+ panic!("Metal device without metal feature enabled: {:?}", device)
+ }
+ }
+ }
+
+ fn bench_name<S: Into<String>>(&self, name: S) -> String {
+ match self {
+ Device::Cpu => {
+ let cpu_type = if cfg!(feature = "accelerate") {
+ "accelerate"
+ } else if cfg!(feature = "mkl") {
+ "mkl"
+ } else {
+ "cpu"
+ };
+ format!("{}_{}", cpu_type, name.into())
+ }
+ Device::Cuda(_) => format!("cuda_{}", name.into()),
+ Device::Metal(_) => format!("metal_{}", name.into()),
+ }
+ }
+}
+
+struct BenchDeviceHandler {
+ devices: Vec<Device>,
+}
+
+impl BenchDeviceHandler {
+ pub fn new() -> Result<Self> {
+ let mut devices = Vec::new();
+ if cfg!(feature = "metal") {
+ devices.push(Device::new_metal(0)?);
+ } else if cfg!(feature = "cuda") {
+ devices.push(Device::new_cuda(0)?);
+ }
+ devices.push(Device::Cpu);
+ Ok(Self { devices })
+ }
+}