diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-07-09 12:38:11 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-09 12:38:11 +0200 |
commit | 25960676caefcb10060fb36a8d66efa9fa731dec (patch) | |
tree | 6f2f10be8bb7389cb2dda3e9e5c0cd7bff35c64f /candle-core | |
parent | 9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144 (diff) | |
download | candle-25960676caefcb10060fb36a8d66efa9fa731dec.tar.gz candle-25960676caefcb10060fb36a8d66efa9fa731dec.tar.bz2 candle-25960676caefcb10060fb36a8d66efa9fa731dec.zip |
Add a basic metal example with capture (#2324)
* Add some tracing.
* Get the trace to work.
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-core/examples/metal_basics.rs | 28 | ||||
-rw-r--r-- | candle-core/src/metal_backend/device.rs | 8 |
3 files changed, 39 insertions, 1 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 92a04917..cbf8f200 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false + +[[example]] +name = "metal_basics" +required-features = ["metal"] diff --git a/candle-core/examples/metal_basics.rs b/candle-core/examples/metal_basics.rs new file mode 100644 index 00000000..f9ff81ad --- /dev/null +++ b/candle-core/examples/metal_basics.rs @@ -0,0 +1,28 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use candle_core::{Device, Tensor}; + +fn main() -> Result<()> { + // This requires the code to be run with MTL_CAPTURE_ENABLED=1 + let device = Device::new_metal(0)?; + let metal_device = match &device { + Device::Metal(m) => m, + _ => anyhow::bail!("unexpected device"), + }; + metal_device.capture("/tmp/candle.gputrace")?; + // This first synchronize ensures that a new command buffer gets created after setting up the + // capture scope. + device.synchronize()?; + let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?; + let x1 = x.add(&x)?; + println!("{x1:?}"); + // This second synchronize ensures that the command buffer gets commited before the end of the + // capture scope. + device.synchronize()?; + Ok(()) +} diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 785fe621..07210c68 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -273,7 +273,13 @@ impl MetalDevice { let descriptor = metal::CaptureDescriptor::new(); descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); descriptor.set_capture_device(self); - descriptor.set_output_url(path); + // The [set_output_url] call requires an absolute path so we convert it if needed. + if path.as_ref().is_absolute() { + descriptor.set_output_url(path); + } else { + let path = std::env::current_dir()?.join(path); + descriptor.set_output_url(path); + } capture .start_capture(&descriptor) |