summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/Cargo.toml3
-rw-r--r--candle-examples/examples/stable-diffusion-3/main.rs15
-rw-r--r--candle-examples/examples/stable-diffusion-3/sampling.rs2
-rw-r--r--candle-transformers/src/models/marian.rs3
4 files changed, 11 insertions, 12 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index d3e23b92..0c1219d7 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -122,6 +122,3 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]
-
-[[example]]
-name = "stable-diffusion-3" \ No newline at end of file
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 164ae420..ee467839 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -30,9 +30,9 @@ struct Args {
#[arg(long)]
cpu: bool,
- /// The CUDA device ID to use.
- #[arg(long, default_value = "0")]
- cuda_device_id: usize,
+ /// The GPU device ID to use.
+ #[arg(long, default_value_t = 0)]
+ gpu_device_id: usize,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
@@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
prompt,
uncond_prompt,
cpu,
- cuda_device_id,
+ gpu_device_id,
tracing,
use_flash_attn,
height,
@@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
None
};
- // TODO: Support and test on Metal.
let device = if cpu {
candle::Device::Cpu
+ } else if candle::utils::cuda_is_available() {
+ candle::Device::new_cuda(gpu_device_id)?
+ } else if candle::utils::metal_is_available() {
+ candle::Device::new_metal(gpu_device_id)?
} else {
- candle::Device::cuda_if_available(cuda_device_id)?
+ candle::Device::Cpu
};
let api = hf_hub::api::sync::Api::new()?;
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index 147d8e73..0efd160e 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -31,7 +31,7 @@ pub fn euler_sample(
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
- &Tensor::full(timestep, (2,), x.device())?.contiguous()?,
+ &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
)?;
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index c4299da6..e93370c2 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -1,9 +1,8 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
-use serde::Deserialize;
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub decoder_vocab_size: Option<usize>,