diff options
Diffstat (limited to 'candle-nn/benches/benchmarks/mod.rs')
-rw-r--r-- | candle-nn/benches/benchmarks/mod.rs | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs new file mode 100644 index 00000000..30a6ab6a --- /dev/null +++ b/candle-nn/benches/benchmarks/mod.rs @@ -0,0 +1,64 @@ +pub(crate) mod conv; +pub(crate) mod layer_norm; + +use candle::{Device, Result}; + +pub(crate) trait BenchDevice { + fn sync(&self) -> Result<()>; + + fn bench_name<S: Into<String>>(&self, name: S) -> String; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + return Ok(device.synchronize()?); + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } + + fn bench_name<S: Into<String>>(&self, name: S) -> String { + match self { + Device::Cpu => { + let cpu_type = if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; + format!("{}_{}", cpu_type, name.into()) + } + Device::Cuda(_) => format!("cuda_{}", name.into()), + Device::Metal(_) => format!("metal_{}", name.into()), + } + } +} + +struct BenchDeviceHandler { + devices: Vec<Device>, +} + +impl BenchDeviceHandler { + pub fn new() -> Result<Self> { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) + } +} |