summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion-3/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/main.rs')
-rw-r--r--candle-examples/examples/stable-diffusion-3/main.rs15
1 files changed, 9 insertions, 6 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 164ae420..ee467839 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -30,9 +30,9 @@ struct Args {
#[arg(long)]
cpu: bool,
- /// The CUDA device ID to use.
- #[arg(long, default_value = "0")]
- cuda_device_id: usize,
+ /// The GPU device ID to use.
+ #[arg(long, default_value_t = 0)]
+ gpu_device_id: usize,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
@@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
prompt,
uncond_prompt,
cpu,
- cuda_device_id,
+ gpu_device_id,
tracing,
use_flash_attn,
height,
@@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
None
};
- // TODO: Support and test on Metal.
let device = if cpu {
candle::Device::Cpu
+ } else if candle::utils::cuda_is_available() {
+ candle::Device::new_cuda(gpu_device_id)?
+ } else if candle::utils::metal_is_available() {
+ candle::Device::new_metal(gpu_device_id)?
} else {
- candle::Device::cuda_if_available(cuda_device_id)?
+ candle::Device::Cpu
};
let api = hf_hub::api::sync::Api::new()?;