summaryrefslogtreecommitdiff
path: root/candle-examples/examples/wuerstchen/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/wuerstchen/main.rs')
-rw-r--r--candle-examples/examples/wuerstchen/main.rs14
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index aaa9b78a..95f3b8f4 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -41,6 +41,9 @@ struct Args {
#[arg(long)]
tracing: bool,
+ #[arg(long)]
+ use_flash_attn: bool,
+
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
@@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
- /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
- /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+ /* c_in */ PRIOR_CIN,
+ /* c */ 1536,
+ /* c_cond */ 1280,
+ /* c_r */ 64,
+ /* depth */ 32,
+ /* nhead */ 24,
+ args.use_flash_attn,
+ vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
@@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
+ args.use_flash_attn,
vb,
)?
};