diff options
Diffstat (limited to 'candle-examples/examples/wuerstchen/main.rs')
-rw-r--r-- | candle-examples/examples/wuerstchen/main.rs | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index aaa9b78a..95f3b8f4 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -41,6 +41,9 @@ struct Args { #[arg(long)] tracing: bool, + #[arg(long)] + use_flash_attn: bool, + /// The height in pixels of the generated image. #[arg(long)] height: Option<usize>, @@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> { let weights = weights.deserialize()?; let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); wuerstchen::prior::WPrior::new( - /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280, - /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb, + /* c_in */ PRIOR_CIN, + /* c */ 1536, + /* c_cond */ 1280, + /* c_r */ 64, + /* depth */ 32, + /* nhead */ 24, + args.use_flash_attn, + vb, )? }; let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; @@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> { /* c_cond */ 1024, /* clip_embd */ 1024, /* patch_size */ 2, + args.use_flash_attn, vb, )? }; |