diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-26 10:23:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-26 10:23:43 +0200 |
commit | 10d47183c088ce449da13d74f07171c8106cd6dd (patch) | |
tree | b91b0398fcb314e998b9f7f3b23877f63462b232 /candle-examples/examples | |
parent | d01207dbf3fb0ad614e7915c8f5706fbc09902fb (diff) | |
download | candle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.gz candle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.bz2 candle-10d47183c088ce449da13d74f07171c8106cd6dd.zip |
Quantized version of flux. (#2500)
* Quantized version of flux.
* More generic sampling.
* Hook the quantized model.
* Use the newly minted gguf file.
* Fix for the quantized model.
* Default to avoid the faster cuda kernels.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/flux/README.md | 2 | ||||
-rw-r--r-- | candle-examples/examples/flux/main.rs | 83 |
2 files changed, 65 insertions, 20 deletions
diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md index 528f058e..dfc8ad5f 100644 --- a/candle-examples/examples/flux/README.md +++ b/candle-examples/examples/flux/README.md @@ -13,7 +13,7 @@ descriptions, ```bash cargo run --features cuda --example flux -r -- \ - --height 1024 --width 1024 + --height 1024 --width 1024 \ --prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k" ``` diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 539ae6f2..24b1fa2b 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -23,6 +23,10 @@ struct Args { #[arg(long)] cpu: bool, + /// Use the quantized model. + #[arg(long)] + quantized: bool, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -40,6 +44,10 @@ struct Args { #[arg(long, value_enum, default_value = "schnell")] model: Model, + + /// Use the faster kernels which are buggy at the moment. + #[arg(long)] + no_dmmv: bool, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> { tracing, decode_only, model, + quantized, + .. } = args; let width = width.unwrap_or(1360); let height = height.unwrap_or(768); @@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> { }; println!("CLIP\n{clip_emb}"); let img = { - let model_file = match model { - Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?, - Model::Dev => bf_repo.get("flux1-dev.safetensors")?, - }; - let vb = - unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? }; let cfg = match model { Model::Dev => flux::model::Config::dev(), Model::Schnell => flux::model::Config::schnell(), }; let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?; - let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?; + let state = if quantized { + flux::sampling::State::new( + &t5_emb.to_dtype(candle::DType::F32)?, + &clip_emb.to_dtype(candle::DType::F32)?, + &img.to_dtype(candle::DType::F32)?, + )? + } else { + flux::sampling::State::new(&t5_emb, &clip_emb, &img)? + }; let timesteps = match model { Model::Dev => { flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))) } Model::Schnell => flux::sampling::get_schedule(4, None), }; - let model = flux::model::Flux::new(&cfg, vb)?; - println!("{state:?}"); println!("{timesteps:?}"); - flux::sampling::denoise( - &model, - &state.img, - &state.img_ids, - &state.txt, - &state.txt_ids, - &state.vec, - ×teps, - 4., - )? + if quantized { + let model_file = match model { + Model::Schnell => api + .repo(hf_hub::Repo::model("lmz/candle-flux".to_string())) + .get("flux1-schnell.gguf")?, + Model::Dev => todo!(), + }; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + model_file, &device, + )?; + + let model = flux::quantized_model::Flux::new(&cfg, vb)?; + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + ×teps, + 4., + )? + .to_dtype(dtype)? + } else { + let model_file = match model { + Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?, + Model::Dev => bf_repo.get("flux1-dev.safetensors")?, + }; + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? + }; + let model = flux::model::Flux::new(&cfg, vb)?; + flux::sampling::denoise( + &model, + &state.img, + &state.img_ids, + &state.txt, + &state.txt_ids, + &state.vec, + ×teps, + 4., + )? + } }; flux::sampling::unpack(&img, height, width)? } @@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> { fn main() -> Result<()> { let args = Args::parse(); + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(!args.no_dmmv); run(args) } |