summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-09-26 10:23:43 +0200
committer GitHub <noreply@github.com> 2024-09-26 10:23:43 +0200
commit 10d47183c088ce449da13d74f07171c8106cd6dd (patch)
tree b91b0398fcb314e998b9f7f3b23877f63462b232 /candle-examples/examples
parent d01207dbf3fb0ad614e7915c8f5706fbc09902fb (diff)
download candle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.gz
candle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.bz2
candle-10d47183c088ce449da13d74f07171c8106cd6dd.zip
Quantized version of flux. (#2500)
* Quantized version of flux.
* More generic sampling.
* Hook the quantized model.
* Use the newly minted gguf file.
* Fix for the quantized model.
* Default to avoid the faster cuda kernels.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- candle-examples/examples/flux/README.md 2
-rw-r--r-- candle-examples/examples/flux/main.rs 83
2 files changed, 65 insertions, 20 deletions
diff --git a/candle-examples/examples/flux/README.md b/candle-examples/examples/flux/README.md
index 528f058e..dfc8ad5f 100644
--- a/candle-examples/examples/flux/README.md
+++ b/candle-examples/examples/flux/README.md
@@ -13,7 +13,7 @@ descriptions,
```bash
cargo run --features cuda --example flux -r -- \
- --height 1024 --width 1024
+ --height 1024 --width 1024 \
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```
diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 539ae6f2..24b1fa2b 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -23,6 +23,10 @@ struct Args {
#[arg(long)]
cpu: bool,
+ /// Use the quantized model.
+ #[arg(long)]
+ quantized: bool,
+
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -40,6 +44,10 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
+
+ /// Use the faster kernels which are buggy at the moment.
+ #[arg(long)]
+ no_dmmv: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@@ -60,6 +68,8 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
+ quantized,
+ ..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@@ -146,38 +156,71 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
- let model_file = match model {
- Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
- Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
- };
- let vb =
- unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
- let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
+ let state = if quantized {
+ flux::sampling::State::new(
+ &t5_emb.to_dtype(candle::DType::F32)?,
+ &clip_emb.to_dtype(candle::DType::F32)?,
+ &img.to_dtype(candle::DType::F32)?,
+ )?
+ } else {
+ flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
+ };
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
- let model = flux::model::Flux::new(&cfg, vb)?;
-
println!("{state:?}");
println!("{timesteps:?}");
- flux::sampling::denoise(
- &model,
- &state.img,
- &state.img_ids,
- &state.txt,
- &state.txt_ids,
- &state.vec,
- &timesteps,
- 4.,
- )?
+ if quantized {
+ let model_file = match model {
+ Model::Schnell => api
+ .repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
+ .get("flux1-schnell.gguf")?,
+ Model::Dev => todo!(),
+ };
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ model_file, &device,
+ )?;
+
+ let model = flux::quantized_model::Flux::new(&cfg, vb)?;
+ flux::sampling::denoise(
+ &model,
+ &state.img,
+ &state.img_ids,
+ &state.txt,
+ &state.txt_ids,
+ &state.vec,
+ &timesteps,
+ 4.,
+ )?
+ .to_dtype(dtype)?
+ } else {
+ let model_file = match model {
+ Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+ Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+ };
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
+ };
+ let model = flux::model::Flux::new(&cfg, vb)?;
+ flux::sampling::denoise(
+ &model,
+ &state.img,
+ &state.img_ids,
+ &state.txt,
+ &state.txt_ids,
+ &state.vec,
+ &timesteps,
+ 4.,
+ )?
+ }
};
flux::sampling::unpack(&img, height, width)?
}
@@ -206,5 +249,7 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
+ #[cfg(feature = "cuda")]
+ candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
run(args)
}