Diffstat (limited to 'candle-transformers/src/models/flux/autoencoder.rs')
 candle-transformers/src/models/flux/autoencoder.rs | 440 ++++++++++++++++++++
 1 file changed, 440 insertions(+)
diff --git a/candle-transformers/src/models/flux/autoencoder.rs b/candle-transformers/src/models/flux/autoencoder.rs
new file mode 100644
index 00000000..8c2aebbd
--- /dev/null
+++ b/candle-transformers/src/models/flux/autoencoder.rs
@@ -0,0 +1,440 @@
+use candle::{Result, Tensor, D};
+use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
+
+// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9
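+//
+// `ch` is the base channel count, scaled by the entries of `ch_mult` at each
+// resolution level; `z_channels` is the number of latent channels, and
+// `scale_factor`/`shift_factor` define the affine normalization applied to the
+// latent in `AutoEncoder::encode`/`decode`.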
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub resolution: usize,
+    pub in_channels: usize,
+    pub ch: usize,
+    pub out_ch: usize,
+    pub ch_mult: Vec<usize>,
+    pub num_res_blocks: usize,
+    pub z_channels: usize,
+    pub scale_factor: f64,
+    pub shift_factor: f64,
+}
+
+impl Config {
+    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47
+    pub fn dev() -> Self {
+        Self {
+            resolution: 256,
+            in_channels: 3,
+            ch: 128,
+            out_ch: 3,
+            ch_mult: vec![1, 2, 4, 4],
+            num_res_blocks: 2,
+            z_channels: 16,
+            scale_factor: 0.3611,
+            shift_factor: 0.1159,
+        }
+    }
+
+    // https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79
+    pub fn schnell() -> Self {
+        Self {
+            resolution: 256,
+            in_channels: 3,
+            ch: 128,
+            out_ch: 3,
+            ch_mult: vec![1, 2, 4, 4],
+            num_res_blocks: 2,
+            z_channels: 16,
+            scale_factor: 0.3611,
+            shift_factor: 0.1159,
+        }
+    }
+}
+
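+// Standard scaled dot-product attention, softmax(q k^T / sqrt(d)) v, where d is
+// the size of the last dimension of q.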
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+    let dim = q.dim(D::Minus1)?;
+    let scale_factor = 1.0 / (dim as f64).sqrt();
+    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
+}
+
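+// Single-head self-attention over the spatial positions of a feature map. The
+// q/k/v and output projections are 1x1 convolutions, and the result is added
+// back to the input as a residual.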
+#[derive(Debug, Clone)]
+struct AttnBlock {
+    q: Conv2d,
+    k: Conv2d,
+    v: Conv2d,
+    proj_out: Conv2d,
+    norm: GroupNorm,
+}
+
+impl AttnBlock {
+    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+        let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
+        let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
+        let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
+        let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
+        let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
+        Ok(Self {
+            q,
+            k,
+            v,
+            proj_out,
+            norm,
+        })
+    }
+}
+
+impl candle::Module for AttnBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let init_xs = xs;
+        let xs = xs.apply(&self.norm)?;
+        let q = xs.apply(&self.q)?;
+        let k = xs.apply(&self.k)?;
+        let v = xs.apply(&self.v)?;
+        let (b, c, h, w) = q.dims4()?;
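+        // Flatten the spatial dimensions and attend over all h * w positions:
+        // (b, c, h, w) -> (b, 1, h * w, c).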
+        let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
+        let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
+        let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
+        let xs = scaled_dot_product_attention(&q, &k, &v)?;
+        let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
+        xs.apply(&self.proj_out)? + init_xs
+    }
+}
+
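+// Pre-norm residual block: two 3x3 convolutions, each preceded by GroupNorm and
+// a swish activation, with a 1x1 `nin_shortcut` projection on the skip path
+// whenever the input and output channel counts differ.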
+#[derive(Debug, Clone)]
+struct ResnetBlock {
+    norm1: GroupNorm,
+    conv1: Conv2d,
+    norm2: GroupNorm,
+    conv2: Conv2d,
+    nin_shortcut: Option<Conv2d>,
+}
+
+impl ResnetBlock {
+    fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
+        let conv_cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
+        let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
+        let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
+        let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
+        let nin_shortcut = if in_c == out_c {
+            None
+        } else {
+            Some(conv2d(
+                in_c,
+                out_c,
+                1,
+                Default::default(),
+                vb.pp("nin_shortcut"),
+            )?)
+        };
+        Ok(Self {
+            norm1,
+            conv1,
+            norm2,
+            conv2,
+            nin_shortcut,
+        })
+    }
+}
+
+impl candle::Module for ResnetBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let h = xs
+            .apply(&self.norm1)?
+            .apply(&candle_nn::Activation::Swish)?
+            .apply(&self.conv1)?
+            .apply(&self.norm2)?
+            .apply(&candle_nn::Activation::Swish)?
+            .apply(&self.conv2)?;
+        match self.nin_shortcut.as_ref() {
+            None => xs + h,
+            Some(c) => xs.apply(c)? + h,
+        }
+    }
+}
+
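+// 2x spatial downsampling via a stride-2 3x3 convolution. Rather than symmetric
+// padding, one extra row/column of zeros is added on the bottom/right, mirroring
+// the `F.pad(x, (0, 1, 0, 1))` used in the reference implementation.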
+#[derive(Debug, Clone)]
+struct Downsample {
+    conv: Conv2d,
+}
+
+impl Downsample {
+    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+        let conv_cfg = candle_nn::Conv2dConfig {
+            stride: 2,
+            ..Default::default()
+        };
+        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
+        Ok(Self { conv })
+    }
+}
+
+impl candle::Module for Downsample {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
+        let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
+        xs.apply(&self.conv)
+    }
+}
+
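+// 2x spatial upsampling: nearest-neighbor interpolation followed by a 3x3
+// convolution.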
+#[derive(Debug, Clone)]
+struct Upsample {
+    conv: Conv2d,
+}
+
+impl Upsample {
+    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
+        let conv_cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
+        Ok(Self { conv })
+    }
+}
+
+impl candle::Module for Upsample {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (_, _, h, w) = xs.dims4()?;
+        xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
+    }
+}
+
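+// One resolution level of the encoder: a stack of residual blocks followed by
+// an optional downsample (absent at the last level).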
+#[derive(Debug, Clone)]
+struct DownBlock {
+    block: Vec<ResnetBlock>,
+    downsample: Option<Downsample>,
+}
+
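+// Maps images of shape (b, in_channels, h, w) to
+// (b, 2 * z_channels, h / 2^(n - 1), w / 2^(n - 1)) where n = ch_mult.len();
+// the output channels hold the mean and log-variance of the latent distribution.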
+#[derive(Debug, Clone)]
+pub struct Encoder {
+    conv_in: Conv2d,
+    mid_block_1: ResnetBlock,
+    mid_attn_1: AttnBlock,
+    mid_block_2: ResnetBlock,
+    norm_out: GroupNorm,
+    conv_out: Conv2d,
+    down: Vec<DownBlock>,
+}
+
+impl Encoder {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let conv_cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let mut block_in = cfg.ch;
+        let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
+
+        let mut down = Vec::with_capacity(cfg.ch_mult.len());
+        let vb_d = vb.pp("down");
+        for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
+            let mut block = Vec::with_capacity(cfg.num_res_blocks);
+            let vb_d = vb_d.pp(i_level);
+            let vb_b = vb_d.pp("block");
+            let in_ch_mult = if i_level == 0 {
+                1
+            } else {
+                cfg.ch_mult[i_level - 1]
+            };
+            block_in = cfg.ch * in_ch_mult;
+            let block_out = cfg.ch * ch_mult;
+            for i_block in 0..cfg.num_res_blocks {
+                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
+                block.push(b);
+                block_in = block_out;
+            }
+            let downsample = if i_level != cfg.ch_mult.len() - 1 {
+                Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
+            } else {
+                None
+            };
+            let block = DownBlock { block, downsample };
+            down.push(block)
+        }
+
+        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
+        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
+        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
+        let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
+        let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
+        Ok(Self {
+            conv_in,
+            mid_block_1,
+            mid_attn_1,
+            mid_block_2,
+            norm_out,
+            conv_out,
+            down,
+        })
+    }
+}
+
+impl candle_nn::Module for Encoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut h = xs.apply(&self.conv_in)?;
+        for block in self.down.iter() {
+            for b in block.block.iter() {
+                h = h.apply(b)?
+            }
+            if let Some(ds) = block.downsample.as_ref() {
+                h = h.apply(ds)?
+            }
+        }
+        h.apply(&self.mid_block_1)?
+            .apply(&self.mid_attn_1)?
+            .apply(&self.mid_block_2)?
+            .apply(&self.norm_out)?
+            .apply(&candle_nn::Activation::Swish)?
+            .apply(&self.conv_out)
+    }
+}
+
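+// One resolution level of the decoder: a stack of residual blocks followed by
+// an optional upsample (absent at the final, full-resolution level).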
+#[derive(Debug, Clone)]
+struct UpBlock {
+    block: Vec<ResnetBlock>,
+    upsample: Option<Upsample>,
+}
+
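+// Mirror of the encoder: maps latents of shape (b, z_channels, h, w) back to
+// (b, out_ch, h * 2^(n - 1), w * 2^(n - 1)) where n = ch_mult.len().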
+#[derive(Debug, Clone)]
+pub struct Decoder {
+    conv_in: Conv2d,
+    mid_block_1: ResnetBlock,
+    mid_attn_1: AttnBlock,
+    mid_block_2: ResnetBlock,
+    norm_out: GroupNorm,
+    conv_out: Conv2d,
+    up: Vec<UpBlock>,
+}
+
+impl Decoder {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let conv_cfg = candle_nn::Conv2dConfig {
+            padding: 1,
+            ..Default::default()
+        };
+        let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
+        let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
+        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
+        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
+        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
+
+        let mut up = Vec::with_capacity(cfg.ch_mult.len());
+        let vb_u = vb.pp("up");
+        for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
+            let block_out = cfg.ch * ch_mult;
+            let vb_u = vb_u.pp(i_level);
+            let vb_b = vb_u.pp("block");
+            let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
+            for i_block in 0..=cfg.num_res_blocks {
+                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
+                block.push(b);
+                block_in = block_out;
+            }
+            let upsample = if i_level != 0 {
+                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
+            } else {
+                None
+            };
+            let block = UpBlock { block, upsample };
+            up.push(block)
+        }
+        up.reverse();
+
+        let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
+        let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
+        Ok(Self {
+            conv_in,
+            mid_block_1,
+            mid_attn_1,
+            mid_block_2,
+            norm_out,
+            conv_out,
+            up,
+        })
+    }
+}
+
+impl candle_nn::Module for Decoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let h = xs.apply(&self.conv_in)?;
+        let mut h = h
+            .apply(&self.mid_block_1)?
+            .apply(&self.mid_attn_1)?
+            .apply(&self.mid_block_2)?;
+        for block in self.up.iter().rev() {
+            for b in block.block.iter() {
+                h = h.apply(b)?
+            }
+            if let Some(us) = block.upsample.as_ref() {
+                h = h.apply(us)?
+            }
+        }
+        h.apply(&self.norm_out)?
+            .apply(&candle_nn::Activation::Swish)?
+            .apply(&self.conv_out)
+    }
+}
+
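+// Reparameterized sampling from a diagonal Gaussian. The input is split along
+// `chunk_dim` into a mean and a log-variance; when `sample` is true the output
+// is mean + exp(0.5 * logvar) * eps with eps ~ N(0, 1), otherwise just the mean.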
+#[derive(Debug, Clone)]
+pub struct DiagonalGaussian {
+    sample: bool,
+    chunk_dim: usize,
+}
+
+impl DiagonalGaussian {
+    pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
+        Ok(Self { sample, chunk_dim })
+    }
+}
+
+impl candle_nn::Module for DiagonalGaussian {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let chunks = xs.chunk(2, self.chunk_dim)?;
+        if self.sample {
+            let std = (&chunks[1] * 0.5)?.exp()?;
+            &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
+        } else {
+            Ok(chunks[0].clone())
+        }
+    }
+}
+
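+// The full VAE: `encode` maps an image to a sampled, normalized latent and
+// `decode` maps such a latent back to an image; the `Module` impl chains the
+// two. A minimal usage sketch (an untested illustration; it assumes the weights
+// are already loaded into a `VarBuilder`, e.g. via
+// `VarBuilder::from_mmaped_safetensors`):
+//
+//     let cfg = Config::schnell();
+//     let ae = AutoEncoder::new(&cfg, vb)?;
+//     let img = Tensor::zeros((1, 3, 256, 256), DType::F32, &Device::Cpu)?;
+//     let latent = ae.encode(&img)?; // (1, 16, 32, 32) for this config
+//     let recon = ae.decode(&latent)?; // (1, 3, 256, 256)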
+#[derive(Debug, Clone)]
+pub struct AutoEncoder {
+    encoder: Encoder,
+    decoder: Decoder,
+    reg: DiagonalGaussian,
+    shift_factor: f64,
+    scale_factor: f64,
+}
+
+impl AutoEncoder {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
+        let reg = DiagonalGaussian::new(true, 1)?;
+        Ok(Self {
+            encoder,
+            decoder,
+            reg,
+            scale_factor: cfg.scale_factor,
+            shift_factor: cfg.shift_factor,
+        })
+    }
+
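+    // The latent is normalized as (z - shift_factor) * scale_factor on the way
+    // out; `decode` applies the inverse transform before decoding.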
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
+        (z - self.shift_factor)? * self.scale_factor
+    }
+
+    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
+        xs.apply(&self.decoder)
+    }
+}
+
+impl candle::Module for AutoEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.decode(&self.encode(xs)?)
+    }
+}