diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-14 08:59:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-14 08:59:12 +0200 |
commit | 3d1dc06cdb44e2e012559aadd8da7342da9c2ed5 (patch) | |
tree | 216941d76c743e100409069f2c9cded4f4ea6392 /candle-examples | |
parent | f553ab5eb401cc3e1588db7fe987aae37f65d113 (diff) | |
download | candle-3d1dc06cdb44e2e012559aadd8da7342da9c2ed5.tar.gz candle-3d1dc06cdb44e2e012559aadd8da7342da9c2ed5.tar.bz2 candle-3d1dc06cdb44e2e012559aadd8da7342da9c2ed5.zip |
Enable stable-diffusion 3 on metal. (#2560)
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/Cargo.toml | 3 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/main.rs | 15 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/sampling.rs | 2 |
3 files changed, 10 insertions, 10 deletions
```diff
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index d3e23b92..0c1219d7 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -122,6 +122,3 @@ required-features = ["onnx"]
 [[example]]
 name = "colpali"
 required-features = ["pdf2image"]
-
-[[example]]
-name = "stable-diffusion-3"
\ No newline at end of file
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 164ae420..ee467839 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -30,9 +30,9 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// The CUDA device ID to use.
-    #[arg(long, default_value = "0")]
-    cuda_device_id: usize,
+    /// The GPU device ID to use.
+    #[arg(long, default_value_t = 0)]
+    gpu_device_id: usize,
 
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
@@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
         prompt,
         uncond_prompt,
         cpu,
-        cuda_device_id,
+        gpu_device_id,
         tracing,
         use_flash_attn,
         height,
@@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
-    // TODO: Support and test on Metal.
     let device = if cpu {
         candle::Device::Cpu
+    } else if candle::utils::cuda_is_available() {
+        candle::Device::new_cuda(gpu_device_id)?
+    } else if candle::utils::metal_is_available() {
+        candle::Device::new_metal(gpu_device_id)?
     } else {
-        candle::Device::cuda_if_available(cuda_device_id)?
+        candle::Device::Cpu
     };
 
     let api = hf_hub::api::sync::Api::new()?;
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index 147d8e73..0efd160e 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -31,7 +31,7 @@ pub fn euler_sample(
         let timestep = (*s_curr) * 1000.0;
         let noise_pred = mmdit.forward(
             &Tensor::cat(&[x.clone(), x.clone()], 0)?,
-            &Tensor::full(timestep, (2,), x.device())?.contiguous()?,
+            &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,
         )?;
```