summaryrefslogtreecommitdiff
path: root/candle-examples/examples/flux/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-04 11:16:24 +0100
committerGitHub <noreply@github.com>2024-08-04 12:16:24 +0200
commit89eae41efdfde43080cc21a1ba194d61806e06da (patch)
treeafe90a6ad436ee961bbd2d917976c055d6f98fb1 /candle-examples/examples/flux/main.rs
parentc0a559d427c04c0484c14f6052b2ea268af10c9d (diff)
downloadcandle-89eae41efdfde43080cc21a1ba194d61806e06da.tar.gz
candle-89eae41efdfde43080cc21a1ba194d61806e06da.tar.bz2
candle-89eae41efdfde43080cc21a1ba194d61806e06da.zip
Support the flux-dev model too. (#2395)
Diffstat (limited to 'candle-examples/examples/flux/main.rs')
-rw-r--r--candle-examples/examples/flux/main.rs46
1 files changed, 37 insertions, 9 deletions
diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 826174bc..a9278d01 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -37,6 +37,15 @@ struct Args {
#[arg(long)]
decode_only: Option<String>,
+
+ #[arg(long, value_enum, default_value = "schnell")]
+ model: Model,
+}
+
+#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
+enum Model {
+ Schnell,
+ Dev,
}
fn run(args: Args) -> Result<()> {
@@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
width,
tracing,
decode_only,
+ model,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
};
let api = hf_hub::api::sync::Api::new()?;
- let bf_repo = api.repo(hf_hub::Repo::model(
- "black-forest-labs/FLUX.1-schnell".to_string(),
- ));
+ let bf_repo = {
+ let name = match model {
+ Model::Dev => "black-forest-labs/FLUX.1-dev",
+ Model::Schnell => "black-forest-labs/FLUX.1-schnell",
+ };
+ api.repo(hf_hub::Repo::model(name.to_string()))
+ };
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
@@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
- let model_file = bf_repo.get("flux1-schnell.sft")?;
+ let model_file = match model {
+ Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
+ Model::Dev => bf_repo.get("flux1-dev.sft")?,
+ };
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
- let cfg = flux::model::Config::schnell();
- let model = flux::model::Flux::new(&cfg, vb)?;
-
+ let cfg = match model {
+ Model::Dev => flux::model::Config::dev(),
+ Model::Schnell => flux::model::Config::schnell(),
+ };
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
+ let timesteps = match model {
+ Model::Dev => {
+ flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
+ }
+ Model::Schnell => flux::sampling::get_schedule(4, None),
+ };
+ let model = flux::model::Flux::new(&cfg, vb)?;
+
println!("{state:?}");
- let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
@@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
let img = {
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
- let cfg = flux::autoencoder::Config::schnell();
+ let cfg = match model {
+ Model::Dev => flux::autoencoder::Config::dev(),
+ Model::Schnell => flux::autoencoder::Config::schnell(),
+ };
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};