summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-07-09 12:38:11 +0200
committer GitHub <noreply@github.com> 2024-07-09 12:38:11 +0200
commit 25960676caefcb10060fb36a8d66efa9fa731dec (patch)
tree 6f2f10be8bb7389cb2dda3e9e5c0cd7bff35c64f /candle-core
parent 9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144 (diff)
download candle-25960676caefcb10060fb36a8d66efa9fa731dec.tar.gz
candle-25960676caefcb10060fb36a8d66efa9fa731dec.tar.bz2
candle-25960676caefcb10060fb36a8d66efa9fa731dec.zip
Add a basic metal example with capture (#2324)
* Add some tracing. * Get the trace to work.
Diffstat (limited to 'candle-core')
-rw-r--r-- candle-core/Cargo.toml 4
-rw-r--r-- candle-core/examples/metal_basics.rs 28
-rw-r--r-- candle-core/src/metal_backend/device.rs 8
3 files changed, 39 insertions, 1 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 92a04917..cbf8f200 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
harness = false
+
+[[example]]
+name = "metal_basics"
+required-features = ["metal"]
diff --git a/candle-core/examples/metal_basics.rs b/candle-core/examples/metal_basics.rs
new file mode 100644
index 00000000..f9ff81ad
--- /dev/null
+++ b/candle-core/examples/metal_basics.rs
@@ -0,0 +1,28 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle_core::{Device, Tensor};
+
+fn main() -> Result<()> {
+ // This requires the code to be run with MTL_CAPTURE_ENABLED=1
+ let device = Device::new_metal(0)?;
+ let metal_device = match &device {
+ Device::Metal(m) => m,
+ _ => anyhow::bail!("unexpected device"),
+ };
+ metal_device.capture("/tmp/candle.gputrace")?;
+ // This first synchronize ensures that a new command buffer gets created after setting up the
+ // capture scope.
+ device.synchronize()?;
+ let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
+ let x1 = x.add(&x)?;
+ println!("{x1:?}");
+ // This second synchronize ensures that the command buffer gets committed before the end of the
+ // capture scope.
+ device.synchronize()?;
+ Ok(())
+}
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 785fe621..07210c68 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -273,7 +273,13 @@ impl MetalDevice {
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
- descriptor.set_output_url(path);
+ // The [set_output_url] call requires an absolute path so we convert it if needed.
+ if path.as_ref().is_absolute() {
+ descriptor.set_output_url(path);
+ } else {
+ let path = std::env::current_dir()?.join(path);
+ descriptor.set_output_url(path);
+ }
capture
.start_capture(&descriptor)