diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-04 11:16:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-04 12:16:24 +0200 |
commit | 89eae41efdfde43080cc21a1ba194d61806e06da (patch) | |
tree | afe90a6ad436ee961bbd2d917976c055d6f98fb1 /candle-examples/examples/flux/main.rs | |
parent | c0a559d427c04c0484c14f6052b2ea268af10c9d (diff) | |
download | candle-89eae41efdfde43080cc21a1ba194d61806e06da.tar.gz candle-89eae41efdfde43080cc21a1ba194d61806e06da.tar.bz2 candle-89eae41efdfde43080cc21a1ba194d61806e06da.zip |
Support the flux-dev model too. (#2395)
Diffstat (limited to 'candle-examples/examples/flux/main.rs')
-rw-r--r-- | candle-examples/examples/flux/main.rs | 46 |
1 files changed, 37 insertions, 9 deletions
diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 826174bc..a9278d01 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -37,6 +37,15 @@ struct Args { #[arg(long)] decode_only: Option<String>, + + #[arg(long, value_enum, default_value = "schnell")] + model: Model, +} + +#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] +enum Model { + Schnell, + Dev, } fn run(args: Args) -> Result<()> { @@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> { width, tracing, decode_only, + model, } = args; let width = width.unwrap_or(1360); let height = height.unwrap_or(768); @@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> { }; let api = hf_hub::api::sync::Api::new()?; - let bf_repo = api.repo(hf_hub::Repo::model( - "black-forest-labs/FLUX.1-schnell".to_string(), - )); + let bf_repo = { + let name = match model { + Model::Dev => "black-forest-labs/FLUX.1-dev", + Model::Schnell => "black-forest-labs/FLUX.1-schnell", + }; + api.repo(hf_hub::Repo::model(name.to_string())) + }; let device = candle_examples::device(cpu)?; let dtype = device.bf16_default_to_f32(); let img = match decode_only { @@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> { }; println!("CLIP\n{clip_emb}"); let img = { - let model_file = bf_repo.get("flux1-schnell.sft")?; + let model_file = match model { + Model::Schnell => bf_repo.get("flux1-schnell.sft")?, + Model::Dev => bf_repo.get("flux1-dev.sft")?, + }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; - let cfg = flux::model::Config::schnell(); - let model = flux::model::Flux::new(&cfg, vb)?; - + let cfg = match model { + Model::Dev => flux::model::Config::dev(), + Model::Schnell => flux::model::Config::schnell(), + }; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + let timesteps = match model { + Model::Dev => { + flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))) + } + Model::Schnell => flux::sampling::get_schedule(4, None), + }; + let model = flux::model::Flux::new(&cfg, vb)?; + println!("{state:?}"); - let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell println!("{timesteps:?}"); flux::sampling::denoise( &model, @@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> { let img = { let model_file = bf_repo.get("ae.sft")?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; - let cfg = flux::autoencoder::Config::schnell(); + let cfg = match model { + Model::Dev => flux::autoencoder::Config::dev(), + Model::Schnell => flux::autoencoder::Config::schnell(), + }; let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?; model.decode(&img)? }; |