summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion-3/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/main.rs')
-rw-r--r--candle-examples/examples/stable-diffusion-3/main.rs33
1 files changed, 17 insertions, 16 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 702d8eec..01b09101 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -183,26 +183,27 @@ fn main() -> Result<()> {
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;
- let mmdit = MMDiT::new(
- &mmdit_config,
- use_flash_attn,
- vb.pp("model.diffusion_model"),
- )?;
-
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
- let x = sampling::euler_sample(
- &mmdit,
- &y,
- &context,
- num_inference_steps,
- cfg_scale,
- time_shift,
- height,
- width,
- )?;
+ let x = {
+ let mmdit = MMDiT::new(
+ &mmdit_config,
+ use_flash_attn,
+ vb.pp("model.diffusion_model"),
+ )?;
+ sampling::euler_sample(
+ &mmdit,
+ &y,
+ &context,
+ num_inference_steps,
+ cfg_scale,
+ time_shift,
+ height,
+ width,
+ )?
+ };
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",