use super::model::{attention, timestep_embedding, Config, EmbedNd};
use crate::quantized_nn::{linear, linear_b, Linear};
use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, RmsNorm};

// LayerNorm without learned affine parameters: weights fixed to 1 and no bias.
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
    let ws = Tensor::ones(dim, DType::F32, vb.device())?;
    Ok(LayerNorm::new_no_bias(ws, 1e-6))
}

#[derive(Debug, Clone)]
pub struct MlpEmbedder {
    in_layer: Linear,
    out_layer: Linear,
}

impl MlpEmbedder {
    fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
        let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
        let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }
}

impl candle::Module for MlpEmbedder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
    }
}

#[derive(Debug, Clone)]
pub struct QkNorm {
    query_norm: RmsNorm,
    key_norm: RmsNorm,
}

impl QkNorm {
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
        let query_norm = RmsNorm::new(query_norm, 1e-6);
        let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
        let key_norm = RmsNorm::new(key_norm, 1e-6);
        Ok(Self {
            query_norm,
            key_norm,
        })
    }
}

struct ModulationOut {
    shift: Tensor,
    scale: Tensor,
    gate: Tensor,
}

impl ModulationOut {
    fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&(&self.scale + 1.)?)?
            .broadcast_add(&self.shift)
    }

    fn gate(&self, xs: &Tensor) -> Result<Tensor> {
        self.gate.broadcast_mul(xs)
    }
}

#[derive(Debug, Clone)]
struct Modulation1 {
    lin: Linear,
}

impl Modulation1 {
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
        let ys = vec_
            .silu()?
            .apply(&self.lin)?
            .unsqueeze(1)?
            .chunk(3, D::Minus1)?;
        if ys.len() != 3 {
            candle::bail!("unexpected len from chunk {ys:?}")
        }
        Ok(ModulationOut {
            shift: ys[0].clone(),
            scale: ys[1].clone(),
            gate: ys[2].clone(),
        })
    }
}
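/// Produces two `(shift, scale, gate)` triples from the conditioning vector: one modulates the
/// attention branch and the other the MLP branch of a double-stream block.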
#[derive(Debug, Clone)]
struct Modulation2 {
    lin: Linear,
}

impl Modulation2 {
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
        let ys = vec_
            .silu()?
            .apply(&self.lin)?
            .unsqueeze(1)?
            .chunk(6, D::Minus1)?;
        if ys.len() != 6 {
            candle::bail!("unexpected len from chunk {ys:?}")
        }
        let mod1 = ModulationOut {
            shift: ys[0].clone(),
            scale: ys[1].clone(),
            gate: ys[2].clone(),
        };
        let mod2 = ModulationOut {
            shift: ys[3].clone(),
            scale: ys[4].clone(),
            gate: ys[5].clone(),
        };
        Ok((mod1, mod2))
    }
}

#[derive(Debug, Clone)]
pub struct SelfAttention {
    qkv: Linear,
    norm: QkNorm,
    proj: Linear,
    num_heads: usize,
}

impl SelfAttention {
    fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
        let head_dim = dim / num_heads;
        let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
        let proj = linear(dim, dim, vb.pp("proj"))?;
        Ok(Self {
            qkv,
            norm,
            proj,
            num_heads,
        })
    }

    // Split the fused qkv projection into per-head q, k, v with shape (b, heads, seq, head_dim),
    // applying RMS norm to q and k.
    fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let qkv = xs.apply(&self.qkv)?;
        let (b, l, _khd) = qkv.dims3()?;
        let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let q = q.apply(&self.norm.query_norm)?;
        let k = k.apply(&self.norm.key_norm)?;
        Ok((q, k, v))
    }

    #[allow(unused)]
    fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let (q, k, v) = self.qkv(xs)?;
        attention(&q, &k, &v, pe)?.apply(&self.proj)
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    lin1: Linear,
    lin2: Linear,
}

impl Mlp {
    fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
        let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
        let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
        Ok(Self { lin1, lin2 })
    }
}

impl candle::Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
    }
}
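/// Double-stream block: image and text tokens keep separate projections and MLPs but attend
/// jointly over the concatenated sequence, with adaLN-style modulation on both streams.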
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
    img_mod: Modulation2,
    img_norm1: LayerNorm,
    img_attn: SelfAttention,
    img_norm2: LayerNorm,
    img_mlp: Mlp,
    txt_mod: Modulation2,
    txt_norm1: LayerNorm,
    txt_attn: SelfAttention,
    txt_norm2: LayerNorm,
    txt_mlp: Mlp,
}

impl DoubleStreamBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h_sz = cfg.hidden_size;
        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
        let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
        let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
        let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
        let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
        let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
        let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
        let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
        let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
        let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
        let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
        Ok(Self {
            img_mod,
            img_norm1,
            img_attn,
            img_norm2,
            img_mlp,
            txt_mod,
            txt_norm1,
            txt_attn,
            txt_norm2,
            txt_mlp,
        })
    }

    fn forward(
        &self,
        img: &Tensor,
        txt: &Tensor,
        vec_: &Tensor,
        pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
        let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
        let img_modulated = img.apply(&self.img_norm1)?;
        let img_modulated = img_mod1.scale_shift(&img_modulated)?;
        let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;

        let txt_modulated = txt.apply(&self.txt_norm1)?;
        let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;

        // Joint attention over the concatenated text + image sequence.
        let q = Tensor::cat(&[txt_q, img_q], 2)?;
        let k = Tensor::cat(&[txt_k, img_k], 2)?;
        let v = Tensor::cat(&[txt_v, img_v], 2)?;

        let attn = attention(&q, &k, &v, pe)?;
        // Split the attention output back into the text and image streams.
        let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
        let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;

        let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
        let img = (&img
            + img_mod2.gate(
                &img_mod2
                    .scale_shift(&img.apply(&self.img_norm2)?)?
                    .apply(&self.img_mlp)?,
            )?)?;

        let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
        let txt = (&txt
            + txt_mod2.gate(
                &txt_mod2
                    .scale_shift(&txt.apply(&self.txt_norm2)?)?
                    .apply(&self.txt_mlp)?,
            )?)?;

        Ok((img, txt))
    }
}

#[derive(Debug, Clone)]
pub struct SingleStreamBlock {
    linear1: Linear,
    linear2: Linear,
    norm: QkNorm,
    pre_norm: LayerNorm,
    modulation: Modulation1,
    h_sz: usize,
    mlp_sz: usize,
    num_heads: usize,
}

impl SingleStreamBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h_sz = cfg.hidden_size;
        let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
        let head_dim = h_sz / cfg.num_heads;
        let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
        let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
        let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
        let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
        Ok(Self {
            linear1,
            linear2,
            norm,
            pre_norm,
            modulation,
            h_sz,
            mlp_sz,
            num_heads: cfg.num_heads,
        })
    }

    fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let mod_ = self.modulation.forward(vec_)?;
        let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
        // linear1 packs the qkv projection and the MLP input into a single matmul.
        let x_mod = x_mod.apply(&self.linear1)?;
        let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
        let (b, l, _khd) = qkv.dims3()?;
        let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
        let q = q.apply(&self.norm.query_norm)?;
        let k = k.apply(&self.norm.key_norm)?;
        let attn = attention(&q, &k, &v, pe)?;
        let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
        xs + mod_.gate(&output)
    }
}
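/// Final layer: adaLN modulation driven by the conditioning vector, followed by a linear
/// projection from the hidden size to `patch_size * patch_size * out_channels`.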
#[derive(Debug, Clone)]
pub struct LastLayer {
    norm_final: LayerNorm,
    linear: Linear,
    ada_ln_modulation: Linear,
}

impl LastLayer {
    fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
        let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
        let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
        let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
        Ok(Self {
            norm_final,
            linear: linear_,
            ada_ln_modulation,
        })
    }

    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
        let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
        let (shift, scale) = (&chunks[0], &chunks[1]);
        let xs = xs
            .apply(&self.norm_final)?
            .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
            .broadcast_add(&shift.unsqueeze(1)?)?;
        xs.apply(&self.linear)
    }
}

#[derive(Debug, Clone)]
pub struct Flux {
    img_in: Linear,
    txt_in: Linear,
    time_in: MlpEmbedder,
    vector_in: MlpEmbedder,
    guidance_in: Option<MlpEmbedder>,
    pe_embedder: EmbedNd,
    double_blocks: Vec<DoubleStreamBlock>,
    single_blocks: Vec<SingleStreamBlock>,
    final_layer: LastLayer,
}

impl Flux {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
        let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
        let mut double_blocks = Vec::with_capacity(cfg.depth);
        let vb_d = vb.pp("double_blocks");
        for idx in 0..cfg.depth {
            let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
            double_blocks.push(db)
        }
        let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
        let vb_s = vb.pp("single_blocks");
        for idx in 0..cfg.depth_single_blocks {
            let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
            single_blocks.push(sb)
        }
        let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
        let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
        let guidance_in = if cfg.guidance_embed {
            let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
            Some(mlp)
        } else {
            None
        };
        let final_layer =
            LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
        let pe_dim = cfg.hidden_size / cfg.num_heads;
        let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
        Ok(Self {
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            double_blocks,
            single_blocks,
            final_layer,
        })
    }
}

impl super::WithForward for Flux {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        timesteps: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        if txt.rank() != 3 {
            candle::bail!("unexpected shape for txt {:?}", txt.shape())
        }
        if img.rank() != 3 {
            candle::bail!("unexpected shape for img {:?}", img.shape())
        }
        let dtype = img.dtype();
        // Positional embeddings are computed over the concatenated text + image ids.
        let pe = {
            let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
            ids.apply(&self.pe_embedder)?
        };
        let mut txt = txt.apply(&self.txt_in)?;
        let mut img = img.apply(&self.img_in)?;
        // Conditioning vector: timestep embedding, optional guidance embedding, and the
        // pooled text embedding y.
        let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
        let vec_ = match (self.guidance_in.as_ref(), guidance) {
            (Some(g_in), Some(guidance)) => {
                (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
            }
            _ => vec_,
        };
        let vec_ = (vec_ + y.apply(&self.vector_in))?;

        // Double blocks
        for block in self.double_blocks.iter() {
            (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
        }
        // Single blocks
        let mut img = Tensor::cat(&[&txt, &img], 1)?;
        for block in self.single_blocks.iter() {
            img = block.forward(&img, &vec_, &pe)?;
        }
        let img = img.i((.., txt.dim(1)?..))?;
        self.final_layer.forward(&img, &vec_)
    }
}
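// A minimal loading/usage sketch (not part of the original module), kept as a comment since it
// depends on a local checkpoint. It assumes a GGUF file whose tensor names match the
// `vb.pp(..)` paths above and a config constructor such as `super::model::Config::schnell()`;
// the file path and the input tensors are placeholders to adapt.
//
//     use candle_transformers::models::flux::{self, WithForward};
//     use candle_transformers::quantized_var_builder::VarBuilder;
//
//     let device = candle::Device::cuda_if_available(0)?;
//     let vb = VarBuilder::from_gguf("flux1-schnell-q4.gguf", &device)?; // hypothetical path
//     let model = flux::quantized_model::Flux::new(&flux::model::Config::schnell(), vb)?;
//     // img/txt are (batch, seq, features); *_ids carry the positional axes consumed by
//     // EmbedNd; y is the pooled text embedding and timesteps the sampling schedule values.
//     let out = model.forward(&img, &img_ids, &txt, &txt_ids, &timesteps, &y, None)?;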