summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/flux
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-04 07:14:33 +0100
committerGitHub <noreply@github.com>2024-08-04 08:14:33 +0200
commit19db6b97236f4d649253ab277eebc1f3041eeff8 (patch)
tree099c540f02e3247edc4bbf2b286cc6a91dbe00c0 /candle-transformers/src/models/flux
parent0fcb40b229a3fc627cdc86513560d2c917b39550 (diff)
downloadcandle-19db6b97236f4d649253ab277eebc1f3041eeff8.tar.gz
candle-19db6b97236f4d649253ab277eebc1f3041eeff8.tar.bz2
candle-19db6b97236f4d649253ab277eebc1f3041eeff8.zip
Add the flux model for image generation. (#2390)
* Add the flux autoencoder. * Add the encoder down-blocks. * Upsampling in the decoder. * Sketch the flow matching model. * More flux model. * Add some of the positional embeddings. * Add the rope embeddings. * Add the sampling functions. * Add the flux example. * Fix the T5 bits. * Proper T5 tokenizer. * Clip encoder path fix. * Get the clip embeddings. * No configurable weights in layer norm. * More weights related fixes. * Yet another shape fix. * DType fix. * Fix a couple more shape issues. * DType fixes. * Fix the latent dims. * Fix more shape issues. * Autoencoder fixes. * Get some generations out. * Bugfix. * T5 padding. * Clippy fix. * Add the decode only mode. * Fix. * More fixes. * Finally get some generations to work. * Add readme.
Diffstat (limited to 'candle-transformers/src/models/flux')
-rw-r--r--candle-transformers/src/models/flux/autoencoder.rs440
-rw-r--r--candle-transformers/src/models/flux/mod.rs3
-rw-r--r--candle-transformers/src/models/flux/model.rs582
-rw-r--r--candle-transformers/src/models/flux/sampling.rs119
4 files changed, 1144 insertions, 0 deletions
diff --git a/candle-transformers/src/models/flux/autoencoder.rs b/candle-transformers/src/models/flux/autoencoder.rs
new file mode 100644
index 00000000..8c2aebbd
--- /dev/null
+++ b/candle-transformers/src/models/flux/autoencoder.rs
@@ -0,0 +1,440 @@
+use candle::{Result, Tensor, D};
+use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
+
+// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub resolution: usize,
+ pub in_channels: usize,
+ pub ch: usize,
+ pub out_ch: usize,
+ pub ch_mult: Vec<usize>,
+ pub num_res_blocks: usize,
+ pub z_channels: usize,
+ pub scale_factor: f64,
+ pub shift_factor: f64,
+}
+
+impl Config {
+ // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47
+ pub fn dev() -> Self {
+ Self {
+ resolution: 256,
+ in_channels: 3,
+ ch: 128,
+ out_ch: 3,
+ ch_mult: vec![1, 2, 4, 4],
+ num_res_blocks: 2,
+ z_channels: 16,
+ scale_factor: 0.3611,
+ shift_factor: 0.1159,
+ }
+ }
+
+ // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79
+ pub fn schnell() -> Self {
+ Self {
+ resolution: 256,
+ in_channels: 3,
+ ch: 128,
+ out_ch: 3,
+ ch_mult: vec![1, 2, 4, 4],
+ num_res_blocks: 2,
+ z_channels: 16,
+ scale_factor: 0.3611,
+ shift_factor: 0.1159,
+ }
+ }
+}
+
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+ let dim = q.dim(D::Minus1)?;
+ let scale_factor = 1.0 / (dim as f64).sqrt();
+ let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+ candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
+}
+
+#[derive(Debug, Clone)]
+struct AttnBlock {
+ q: Conv2d,
+ k: Conv2d,
+ v: Conv2d,
+ proj_out: Conv2d,
+ norm: GroupNorm,
+}
+
+impl AttnBlock {
+ fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+ let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
+ let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
+ let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
+ let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
+ let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
+ Ok(Self {
+ q,
+ k,
+ v,
+ proj_out,
+ norm,
+ })
+ }
+}
+
+impl candle::Module for AttnBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let init_xs = xs;
+ let xs = xs.apply(&self.norm)?;
+ let q = xs.apply(&self.q)?;
+ let k = xs.apply(&self.k)?;
+ let v = xs.apply(&self.v)?;
+ let (b, c, h, w) = q.dims4()?;
+ let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
+ let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
+ let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
+ let xs = scaled_dot_product_attention(&q, &k, &v)?;
+ let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
+ xs.apply(&self.proj_out)? + init_xs
+ }
+}
+
+#[derive(Debug, Clone)]
+struct ResnetBlock {
+ norm1: GroupNorm,
+ conv1: Conv2d,
+ norm2: GroupNorm,
+ conv2: Conv2d,
+ nin_shortcut: Option<Conv2d>,
+}
+
+impl ResnetBlock {
+ fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
+ let conv_cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
+ let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
+ let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
+ let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
+ let nin_shortcut = if in_c == out_c {
+ None
+ } else {
+ Some(conv2d(
+ in_c,
+ out_c,
+ 1,
+ Default::default(),
+ vb.pp("nin_shortcut"),
+ )?)
+ };
+ Ok(Self {
+ norm1,
+ conv1,
+ norm2,
+ conv2,
+ nin_shortcut,
+ })
+ }
+}
+
+impl candle::Module for ResnetBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let h = xs
+ .apply(&self.norm1)?
+ .apply(&candle_nn::Activation::Swish)?
+ .apply(&self.conv1)?
+ .apply(&self.norm2)?
+ .apply(&candle_nn::Activation::Swish)?
+ .apply(&self.conv2)?;
+ match self.nin_shortcut.as_ref() {
+ None => xs + h,
+ Some(c) => xs.apply(c)? + h,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Downsample {
+ conv: Conv2d,
+}
+
+impl Downsample {
+ fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+ let conv_cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl candle::Module for Downsample {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
+ let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
+ xs.apply(&self.conv)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Upsample {
+ conv: Conv2d,
+}
+
+impl Upsample {
+ fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+ let conv_cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl candle::Module for Upsample {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_, _, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DownBlock {
+ block: Vec<ResnetBlock>,
+ downsample: Option<Downsample>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Encoder {
+ conv_in: Conv2d,
+ mid_block_1: ResnetBlock,
+ mid_attn_1: AttnBlock,
+ mid_block_2: ResnetBlock,
+ norm_out: GroupNorm,
+ conv_out: Conv2d,
+ down: Vec<DownBlock>,
+}
+
+impl Encoder {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let conv_cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let mut block_in = cfg.ch;
+ let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
+
+ let mut down = Vec::with_capacity(cfg.ch_mult.len());
+ let vb_d = vb.pp("down");
+ for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
+ let mut block = Vec::with_capacity(cfg.num_res_blocks);
+ let vb_d = vb_d.pp(i_level);
+ let vb_b = vb_d.pp("block");
+ let in_ch_mult = if i_level == 0 {
+ 1
+ } else {
+ cfg.ch_mult[i_level - 1]
+ };
+ block_in = cfg.ch * in_ch_mult;
+ let block_out = cfg.ch * ch_mult;
+ for i_block in 0..cfg.num_res_blocks {
+ let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
+ block.push(b);
+ block_in = block_out;
+ }
+ let downsample = if i_level != cfg.ch_mult.len() - 1 {
+ Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
+ } else {
+ None
+ };
+ let block = DownBlock { block, downsample };
+ down.push(block)
+ }
+
+ let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
+ let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
+ let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
+ let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
+ let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
+ Ok(Self {
+ conv_in,
+ mid_block_1,
+ mid_attn_1,
+ mid_block_2,
+ norm_out,
+ conv_out,
+ down,
+ })
+ }
+}
+
+impl candle_nn::Module for Encoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut h = xs.apply(&self.conv_in)?;
+ for block in self.down.iter() {
+ for b in block.block.iter() {
+ h = h.apply(b)?
+ }
+ if let Some(ds) = block.downsample.as_ref() {
+ h = h.apply(ds)?
+ }
+ }
+ h.apply(&self.mid_block_1)?
+ .apply(&self.mid_attn_1)?
+ .apply(&self.mid_block_2)?
+ .apply(&self.norm_out)?
+ .apply(&candle_nn::Activation::Swish)?
+ .apply(&self.conv_out)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct UpBlock {
+ block: Vec<ResnetBlock>,
+ upsample: Option<Upsample>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Decoder {
+ conv_in: Conv2d,
+ mid_block_1: ResnetBlock,
+ mid_attn_1: AttnBlock,
+ mid_block_2: ResnetBlock,
+ norm_out: GroupNorm,
+ conv_out: Conv2d,
+ up: Vec<UpBlock>,
+}
+
+impl Decoder {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let conv_cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
+ let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
+ let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
+ let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
+ let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
+
+ let mut up = Vec::with_capacity(cfg.ch_mult.len());
+ let vb_u = vb.pp("up");
+ for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
+ let block_out = cfg.ch * ch_mult;
+ let vb_u = vb_u.pp(i_level);
+ let vb_b = vb_u.pp("block");
+ let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
+ for i_block in 0..=cfg.num_res_blocks {
+ let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
+ block.push(b);
+ block_in = block_out;
+ }
+ let upsample = if i_level != 0 {
+ Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
+ } else {
+ None
+ };
+ let block = UpBlock { block, upsample };
+ up.push(block)
+ }
+ up.reverse();
+
+ let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
+ let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
+ Ok(Self {
+ conv_in,
+ mid_block_1,
+ mid_attn_1,
+ mid_block_2,
+ norm_out,
+ conv_out,
+ up,
+ })
+ }
+}
+
+impl candle_nn::Module for Decoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let h = xs.apply(&self.conv_in)?;
+ let mut h = h
+ .apply(&self.mid_block_1)?
+ .apply(&self.mid_attn_1)?
+ .apply(&self.mid_block_2)?;
+ for block in self.up.iter().rev() {
+ for b in block.block.iter() {
+ h = h.apply(b)?
+ }
+ if let Some(us) = block.upsample.as_ref() {
+ h = h.apply(us)?
+ }
+ }
+ h.apply(&self.norm_out)?
+ .apply(&candle_nn::Activation::Swish)?
+ .apply(&self.conv_out)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DiagonalGaussian {
+ sample: bool,
+ chunk_dim: usize,
+}
+
+impl DiagonalGaussian {
+ pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
+ Ok(Self { sample, chunk_dim })
+ }
+}
+
+impl candle_nn::Module for DiagonalGaussian {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let chunks = xs.chunk(2, self.chunk_dim)?;
+ if self.sample {
+ let std = (&chunks[1] * 0.5)?.exp()?;
+ &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
+ } else {
+ Ok(chunks[0].clone())
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoder {
+ encoder: Encoder,
+ decoder: Decoder,
+ reg: DiagonalGaussian,
+ shift_factor: f64,
+ scale_factor: f64,
+}
+
+impl AutoEncoder {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+ let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
+ let reg = DiagonalGaussian::new(true, 1)?;
+ Ok(Self {
+ encoder,
+ decoder,
+ reg,
+ scale_factor: cfg.scale_factor,
+ shift_factor: cfg.shift_factor,
+ })
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
+ (z - self.shift_factor)? * self.scale_factor
+ }
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
+ xs.apply(&self.decoder)
+ }
+}
+
+impl candle::Module for AutoEncoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.decode(&self.encode(xs)?)
+ }
+}
diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs
new file mode 100644
index 00000000..763fa90d
--- /dev/null
+++ b/candle-transformers/src/models/flux/mod.rs
@@ -0,0 +1,3 @@
+pub mod autoencoder;
+pub mod model;
+pub mod sampling;
diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
new file mode 100644
index 00000000..aa00077e
--- /dev/null
+++ b/candle-transformers/src/models/flux/model.rs
@@ -0,0 +1,582 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
+
+// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub in_channels: usize,
+ pub vec_in_dim: usize,
+ pub context_in_dim: usize,
+ pub hidden_size: usize,
+ pub mlp_ratio: f64,
+ pub num_heads: usize,
+ pub depth: usize,
+ pub depth_single_blocks: usize,
+ pub axes_dim: Vec<usize>,
+ pub theta: usize,
+ pub qkv_bias: bool,
+ pub guidance_embed: bool,
+}
+
+impl Config {
+ // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
+ pub fn dev() -> Self {
+ Self {
+ in_channels: 64,
+ vec_in_dim: 768,
+ context_in_dim: 4096,
+ hidden_size: 3072,
+ mlp_ratio: 4.0,
+ num_heads: 24,
+ depth: 19,
+ depth_single_blocks: 38,
+ axes_dim: vec![16, 56, 56],
+ theta: 10_000,
+ qkv_bias: true,
+ guidance_embed: true,
+ }
+ }
+
+ // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
+ pub fn schnell() -> Self {
+ Self {
+ in_channels: 64,
+ vec_in_dim: 768,
+ context_in_dim: 4096,
+ hidden_size: 3072,
+ mlp_ratio: 4.0,
+ num_heads: 24,
+ depth: 19,
+ depth_single_blocks: 38,
+ axes_dim: vec![16, 56, 56],
+ theta: 10_000,
+ qkv_bias: true,
+ guidance_embed: false,
+ }
+ }
+}
+
+fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
+ let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
+ Ok(LayerNorm::new_no_bias(ws, 1e-6))
+}
+
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+ let dim = q.dim(D::Minus1)?;
+ let scale_factor = 1.0 / (dim as f64).sqrt();
+ let mut batch_dims = q.dims().to_vec();
+ batch_dims.pop();
+ batch_dims.pop();
+ let q = q.flatten_to(batch_dims.len() - 1)?;
+ let k = k.flatten_to(batch_dims.len() - 1)?;
+ let v = v.flatten_to(batch_dims.len() - 1)?;
+ let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+ let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
+ batch_dims.push(attn_scores.dim(D::Minus2)?);
+ batch_dims.push(attn_scores.dim(D::Minus1)?);
+ attn_scores.reshape(batch_dims)
+}
+
+fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
+ if dim % 2 == 1 {
+ candle::bail!("dim {dim} is odd")
+ }
+ let dev = pos.device();
+ let theta = theta as f64;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
+ let inv_freq = inv_freq.to_dtype(pos.dtype())?;
+ let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
+ let cos = freqs.cos()?;
+ let sin = freqs.sin()?;
+ let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
+ let (b, n, d, _ij) = out.dims4()?;
+ out.reshape((b, n, d, 2, 2))
+}
+
+fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
+ let dims = x.dims();
+ let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+ let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
+ let x0 = x.narrow(D::Minus1, 0, 1)?;
+ let x1 = x.narrow(D::Minus1, 1, 1)?;
+ let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
+ let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
+ (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
+}
+
+fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+ let q = apply_rope(q, pe)?.contiguous()?;
+ let k = apply_rope(k, pe)?.contiguous()?;
+ let x = scaled_dot_product_attention(&q, &k, v)?;
+ x.transpose(1, 2)?.flatten_from(2)
+}
+
+fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+ const TIME_FACTOR: f64 = 1000.;
+ const MAX_PERIOD: f64 = 10000.;
+ if dim % 2 == 1 {
+ candle::bail!("{dim} is odd")
+ }
+ let dev = t.device();
+ let half = dim / 2;
+ let t = (t * TIME_FACTOR)?;
+ let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
+ let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
+ let args = t
+ .unsqueeze(1)?
+ .to_dtype(candle::DType::F32)?
+ .broadcast_mul(&freqs.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
+ Ok(emb)
+}
+
+#[derive(Debug, Clone)]
+pub struct EmbedNd {
+ #[allow(unused)]
+ dim: usize,
+ theta: usize,
+ axes_dim: Vec<usize>,
+}
+
+impl EmbedNd {
+ fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+ Self {
+ dim,
+ theta,
+ axes_dim,
+ }
+ }
+}
+
+impl candle::Module for EmbedNd {
+ fn forward(&self, ids: &Tensor) -> Result<Tensor> {
+ let n_axes = ids.dim(D::Minus1)?;
+ let mut emb = Vec::with_capacity(n_axes);
+ for idx in 0..n_axes {
+ let r = rope(
+ &ids.get_on_dim(D::Minus1, idx)?,
+ self.axes_dim[idx],
+ self.theta,
+ )?;
+ emb.push(r)
+ }
+ let emb = Tensor::cat(&emb, 2)?;
+ emb.unsqueeze(1)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct MlpEmbedder {
+ in_layer: Linear,
+ out_layer: Linear,
+}
+
+impl MlpEmbedder {
+ fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
+ let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
+ let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
+ Ok(Self {
+ in_layer,
+ out_layer,
+ })
+ }
+}
+
+impl candle::Module for MlpEmbedder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct QkNorm {
+ query_norm: RmsNorm,
+ key_norm: RmsNorm,
+}
+
+impl QkNorm {
+ fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let query_norm = vb.get(dim, "query_norm.scale")?;
+ let query_norm = RmsNorm::new(query_norm, 1e-6);
+ let key_norm = vb.get(dim, "key_norm.scale")?;
+ let key_norm = RmsNorm::new(key_norm, 1e-6);
+ Ok(Self {
+ query_norm,
+ key_norm,
+ })
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Modulation {
+ lin: Linear,
+ multiplier: usize,
+}
+
+impl Modulation {
+ fn new(dim: usize, double: bool, vb: VarBuilder) -> Result<Self> {
+ let multiplier = if double { 6 } else { 3 };
+ let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?;
+ Ok(Self { lin, multiplier })
+ }
+
+ fn forward(&self, vec_: &Tensor) -> Result<Vec<Tensor>> {
+ vec_.silu()?
+ .apply(&self.lin)?
+ .unsqueeze(1)?
+ .chunk(self.multiplier, D::Minus1)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SelfAttention {
+ qkv: Linear,
+ norm: QkNorm,
+ proj: Linear,
+ num_heads: usize,
+}
+
+impl SelfAttention {
+ fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
+ let head_dim = dim / num_heads;
+ let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
+ let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
+ let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
+ Ok(Self {
+ qkv,
+ norm,
+ proj,
+ num_heads,
+ })
+ }
+
+ fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
+ let qkv = xs.apply(&self.qkv)?;
+ let (b, l, _khd) = qkv.dims3()?;
+ let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
+ let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
+ let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
+ let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
+ let q = q.apply(&self.norm.query_norm)?;
+ let k = k.apply(&self.norm.key_norm)?;
+ Ok((q, k, v))
+ }
+
+ #[allow(unused)]
+ fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
+ let (q, k, v) = self.qkv(xs)?;
+ attention(&q, &k, &v, pe)?.apply(&self.proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+ lin1: Linear,
+ lin2: Linear,
+}
+
+impl Mlp {
+ fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
+ let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
+ let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
+ Ok(Self { lin1, lin2 })
+ }
+}
+
+impl candle::Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DoubleStreamBlock {
+ img_mod: Modulation,
+ img_norm1: LayerNorm,
+ img_attn: SelfAttention,
+ img_norm2: LayerNorm,
+ img_mlp: Mlp,
+ txt_mod: Modulation,
+ txt_norm1: LayerNorm,
+ txt_attn: SelfAttention,
+ txt_norm2: LayerNorm,
+ txt_mlp: Mlp,
+}
+
+impl DoubleStreamBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let h_sz = cfg.hidden_size;
+ let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
+ let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?;
+ let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
+ let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
+ let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
+ let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
+ let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?;
+ let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
+ let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
+ let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
+ let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
+ Ok(Self {
+ img_mod,
+ img_norm1,
+ img_attn,
+ img_norm2,
+ img_mlp,
+ txt_mod,
+ txt_norm1,
+ txt_attn,
+ txt_norm2,
+ txt_mlp,
+ })
+ }
+
+ fn forward(
+ &self,
+ img: &Tensor,
+ txt: &Tensor,
+ vec_: &Tensor,
+ pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate
+ let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate
+ let img_modulated = img.apply(&self.img_norm1)?;
+ let img_modulated = img_modulated
+ .broadcast_mul(&(&img_mod[1] + 1.)?)?
+ .broadcast_add(&img_mod[0])?;
+ let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
+
+ let txt_modulated = txt.apply(&self.txt_norm1)?;
+ let txt_modulated = txt_modulated
+ .broadcast_mul(&(&txt_mod[1] + 1.)?)?
+ .broadcast_add(&txt_mod[0])?;
+ let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
+
+ let q = Tensor::cat(&[txt_q, img_q], 2)?;
+ let k = Tensor::cat(&[txt_k, img_k], 2)?;
+ let v = Tensor::cat(&[txt_v, img_v], 2)?;
+
+ let attn = attention(&q, &k, &v, pe)?;
+ let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
+ let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
+
+ let img = (img
+ + img_attn
+ .apply(&self.img_attn.proj)?
+ .broadcast_mul(&img_mod[2]))?;
+ let img = (&img
+ + &img_mod[5].broadcast_mul(
+ &img.apply(&self.img_norm2)?
+ .broadcast_mul(&(&img_mod[4] + 1.0)?)?
+ .broadcast_add(&img_mod[3])?
+ .apply(&self.img_mlp)?,
+ )?)?;
+
+ let txt = (txt
+ + txt_attn
+ .apply(&self.txt_attn.proj)?
+ .broadcast_mul(&txt_mod[2]))?;
+ let txt = (&txt
+ + &txt_mod[5].broadcast_mul(
+ &txt.apply(&self.txt_norm2)?
+ .broadcast_mul(&(&txt_mod[4] + 1.0)?)?
+ .broadcast_add(&txt_mod[3])?
+ .apply(&self.txt_mlp)?,
+ )?)?;
+
+ Ok((img, txt))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SingleStreamBlock {
+ linear1: Linear,
+ linear2: Linear,
+ norm: QkNorm,
+ pre_norm: LayerNorm,
+ modulation: Modulation,
+ h_sz: usize,
+ mlp_sz: usize,
+ num_heads: usize,
+}
+
+impl SingleStreamBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let h_sz = cfg.hidden_size;
+ let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
+ let head_dim = h_sz / cfg.num_heads;
+ let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
+ let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
+ let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
+ let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
+ let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?;
+ Ok(Self {
+ linear1,
+ linear2,
+ norm,
+ pre_norm,
+ modulation,
+ h_sz,
+ mlp_sz,
+ num_heads: cfg.num_heads,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
+ let mod_ = self.modulation.forward(vec_)?;
+ let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]);
+ let x_mod = xs
+ .apply(&self.pre_norm)?
+ .broadcast_mul(&(scale + 1.0)?)?
+ .broadcast_add(shift)?;
+ let x_mod = x_mod.apply(&self.linear1)?;
+ let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
+ let (b, l, _khd) = qkv.dims3()?;
+ let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
+ let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
+ let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
+ let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
+ let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
+ let q = q.apply(&self.norm.query_norm)?;
+ let k = k.apply(&self.norm.key_norm)?;
+ let attn = attention(&q, &k, &v, pe)?;
+ let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
+ xs + gate.broadcast_mul(&output)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct LastLayer {
+ norm_final: LayerNorm,
+ linear: Linear,
+ ada_ln_modulation: Linear,
+}
+
+impl LastLayer {
+ fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
+ let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
+ let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
+ let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
+ Ok(Self {
+ norm_final,
+ linear,
+ ada_ln_modulation,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
+ let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
+ let (shift, scale) = (&chunks[0], &chunks[1]);
+ let xs = xs
+ .apply(&self.norm_final)?
+ .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
+ .broadcast_add(&shift.unsqueeze(1)?)?;
+ xs.apply(&self.linear)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Flux {
+ img_in: Linear,
+ txt_in: Linear,
+ time_in: MlpEmbedder,
+ vector_in: MlpEmbedder,
+ guidance_in: Option<MlpEmbedder>,
+ pe_embedder: EmbedNd,
+ double_blocks: Vec<DoubleStreamBlock>,
+ single_blocks: Vec<SingleStreamBlock>,
+ final_layer: LastLayer,
+}
+
+impl Flux {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
+ let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
+ let mut double_blocks = Vec::with_capacity(cfg.depth);
+ let vb_d = vb.pp("double_blocks");
+ for idx in 0..cfg.depth {
+ let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
+ double_blocks.push(db)
+ }
+ let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
+ let vb_s = vb.pp("single_blocks");
+ for idx in 0..cfg.depth_single_blocks {
+ let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
+ single_blocks.push(sb)
+ }
+ let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
+ let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
+ let guidance_in = if cfg.guidance_embed {
+ let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
+ Some(mlp)
+ } else {
+ None
+ };
+ let final_layer =
+ LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
+ let pe_dim = cfg.hidden_size / cfg.num_heads;
+ let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
+ Ok(Self {
+ img_in,
+ txt_in,
+ time_in,
+ vector_in,
+ guidance_in,
+ pe_embedder,
+ double_blocks,
+ single_blocks,
+ final_layer,
+ })
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub fn forward(
+ &self,
+ img: &Tensor,
+ img_ids: &Tensor,
+ txt: &Tensor,
+ txt_ids: &Tensor,
+ timesteps: &Tensor,
+ y: &Tensor,
+ guidance: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ if txt.rank() != 3 {
+ candle::bail!("unexpected shape for txt {:?}", txt.shape())
+ }
+ if img.rank() != 3 {
+ candle::bail!("unexpected shape for img {:?}", img.shape())
+ }
+ let dtype = img.dtype();
+ let pe = {
+ let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
+ ids.apply(&self.pe_embedder)?
+ };
+ let mut txt = txt.apply(&self.txt_in)?;
+ let mut img = img.apply(&self.img_in)?;
+ let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
+ let vec_ = match (self.guidance_in.as_ref(), guidance) {
+ (Some(g_in), Some(guidance)) => {
+ (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
+ }
+ _ => vec_,
+ };
+ let vec_ = (vec_ + y.apply(&self.vector_in))?;
+
+ // Double blocks
+ for block in self.double_blocks.iter() {
+ (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
+ }
+ // Single blocks
+ let mut img = Tensor::cat(&[&txt, &img], 1)?;
+ for block in self.single_blocks.iter() {
+ img = block.forward(&img, &vec_, &pe)?;
+ }
+ let img = img.i((.., txt.dim(1)?..))?;
+ self.final_layer.forward(&img, &vec_)
+ }
+}
diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs
new file mode 100644
index 00000000..89b9a953
--- /dev/null
+++ b/candle-transformers/src/models/flux/sampling.rs
@@ -0,0 +1,119 @@
+use candle::{Device, Result, Tensor};
+
+pub fn get_noise(
+ num_samples: usize,
+ height: usize,
+ width: usize,
+ device: &Device,
+) -> Result<Tensor> {
+ let height = (height + 15) / 16 * 2;
+ let width = (width + 15) / 16 * 2;
+ Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
+}
+
+#[derive(Debug, Clone)]
+pub struct State {
+ pub img: Tensor,
+ pub img_ids: Tensor,
+ pub txt: Tensor,
+ pub txt_ids: Tensor,
+ pub vec: Tensor,
+}
+
+impl State {
+ pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result<Self> {
+ let dtype = img.dtype();
+ let (bs, c, h, w) = img.dims4()?;
+ let dev = img.device();
+ let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; // (b, c, h, ph, w, pw)
+ let img = img.permute((0, 2, 4, 1, 3, 5))?; // (b, h, w, c, ph, pw)
+ let img = img.reshape((bs, h / 2 * w / 2, c * 4))?;
+ let img_ids = Tensor::stack(
+ &[
+ Tensor::full(0u32, (h / 2, w / 2), dev)?,
+ Tensor::arange(0u32, h as u32 / 2, dev)?
+ .reshape(((), 1))?
+ .broadcast_as((h / 2, w / 2))?,
+ Tensor::arange(0u32, w as u32 / 2, dev)?
+ .reshape((1, ()))?
+ .broadcast_as((h / 2, w / 2))?,
+ ],
+ 2,
+ )?
+ .to_dtype(dtype)?;
+ let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?;
+ let img_ids = img_ids.repeat((bs, 1, 1))?;
+ let txt = t5_emb.repeat(bs)?;
+ let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?;
+ let vec = clip_emb.repeat(bs)?;
+ Ok(Self {
+ img,
+ img_ids,
+ txt,
+ txt_ids,
+ vec,
+ })
+ }
+}
+
+fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
+ let e = mu.exp();
+ e / (e + (1. / t - 1.).powf(sigma))
+}
+
+/// `shift` is a triple `(image_seq_len, base_shift, max_shift)`.
+pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f64> {
+ let timesteps: Vec<f64> = (0..=num_steps)
+ .map(|v| v as f64 / num_steps as f64)
+ .rev()
+ .collect();
+ match shift {
+ None => timesteps,
+ Some((image_seq_len, y1, y2)) => {
+ let (x1, x2) = (256., 4096.);
+ let m = (y2 - y1) / (x2 - x1);
+ let b = y1 - m * x1;
+ let mu = m * image_seq_len as f64 + b;
+ timesteps
+ .into_iter()
+ .map(|v| time_shift(mu, 1., v))
+ .collect()
+ }
+ }
+}
+
+pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
+ let (b, _h_w, c_ph_pw) = xs.dims3()?;
+ let height = (height + 15) / 16;
+ let width = (width + 15) / 16;
+ xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
+ .permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
+ .reshape((b, c_ph_pw / 4, height * 2, width * 2))
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn denoise(
+ model: &super::model::Flux,
+ img: &Tensor,
+ img_ids: &Tensor,
+ txt: &Tensor,
+ txt_ids: &Tensor,
+ vec_: &Tensor,
+ timesteps: &[f64],
+ guidance: f64,
+) -> Result<Tensor> {
+ let b_sz = img.dim(0)?;
+ let dev = img.device();
+ let guidance = Tensor::full(guidance as f32, b_sz, dev)?;
+ let mut img = img.clone();
+ for window in timesteps.windows(2) {
+ let (t_curr, t_prev) = match window {
+ [a, b] => (a, b),
+ _ => continue,
+ };
+ let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?;
+ let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?;
+ img = (img + pred * (t_prev - t_curr))?
+ }
+ Ok(img)
+}