diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/main.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/main.rs | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index 164ae420..ee467839 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -30,9 +30,9 @@ struct Args { #[arg(long)] cpu: bool, - /// The CUDA device ID to use. - #[arg(long, default_value = "0")] - cuda_device_id: usize, + /// The GPU device ID to use. + #[arg(long, default_value_t = 0)] + gpu_device_id: usize, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] @@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> { prompt, uncond_prompt, cpu, - cuda_device_id, + gpu_device_id, tracing, use_flash_attn, height, @@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> { None }; - // TODO: Support and test on Metal. let device = if cpu { candle::Device::Cpu + } else if candle::utils::cuda_is_available() { + candle::Device::new_cuda(gpu_device_id)? + } else if candle::utils::metal_is_available() { + candle::Device::new_metal(gpu_device_id)? } else { - candle::Device::cuda_if_available(cuda_device_id)? + candle::Device::Cpu }; let api = hf_hub::api::sync::Api::new()?; |