//! 2D UNet Denoising Models
//!
//! The 2D UNet models take as input a noisy sample and the current diffusion
//! timestep and return a denoised version of the input.
use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;

#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
    pub out_channels: usize,
    /// When `None`, no cross-attention is used; when `Some(d)`, cross-attention is used and `d`
    /// is the number of transformer blocks.
    pub use_cross_attn: Option<usize>,
    pub attention_head_dim: usize,
}

#[derive(Debug, Clone)]
pub struct UNet2DConditionModelConfig {
    pub center_input_sample: bool,
    pub flip_sin_to_cos: bool,
    pub freq_shift: f64,
    pub blocks: Vec<BlockConfig>,
    pub layers_per_block: usize,
    pub downsample_padding: usize,
    pub mid_block_scale_factor: f64,
    pub norm_num_groups: usize,
    pub norm_eps: f64,
    pub cross_attention_dim: usize,
    pub sliced_attention_size: Option<usize>,
    pub use_linear_projection: bool,
}

impl Default for UNet2DConditionModelConfig {
    fn default() -> Self {
        Self {
            center_input_sample: false,
            flip_sin_to_cos: true,
            freq_shift: 0.,
            blocks: vec![
                BlockConfig {
                    out_channels: 320,
                    use_cross_attn: Some(1),
                    attention_head_dim: 8,
                },
                BlockConfig {
                    out_channels: 640,
                    use_cross_attn: Some(1),
                    attention_head_dim: 8,
                },
                BlockConfig {
                    out_channels: 1280,
                    use_cross_attn: Some(1),
                    attention_head_dim: 8,
                },
                BlockConfig {
                    out_channels: 1280,
                    use_cross_attn: None,
                    attention_head_dim: 8,
                },
            ],
            layers_per_block: 2,
            downsample_padding: 1,
            mid_block_scale_factor: 1.,
            norm_num_groups: 32,
            norm_eps: 1e-5,
            cross_attention_dim: 1280,
            sliced_attention_size: None,
            use_linear_projection: false,
        }
    }
}

#[derive(Debug)]
pub(crate) enum UNetDownBlock {
    Basic(DownBlock2D),
    CrossAttn(CrossAttnDownBlock2D),
}

#[derive(Debug)]
enum UNetUpBlock {
    Basic(UpBlock2D),
    CrossAttn(CrossAttnUpBlock2D),
}

#[derive(Debug)]
pub struct UNet2DConditionModel {
    conv_in: Conv2d,
    time_proj: Timesteps,
    time_embedding: TimestepEmbedding,
    down_blocks: Vec<UNetDownBlock>,
    mid_block: UNetMidBlock2DCrossAttn,
    up_blocks: Vec<UNetUpBlock>,
    conv_norm_out: nn::GroupNorm,
    conv_out: Conv2d,
    span: tracing::Span,
    config: UNet2DConditionModelConfig,
}

impl UNet2DConditionModel {
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        use_flash_attn: bool,
        config: UNet2DConditionModelConfig,
    ) -> Result<Self> {
        let n_blocks = config.blocks.len();
        let b_channels = config.blocks[0].out_channels;
        let bl_channels = config.blocks.last().unwrap().out_channels;
        let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
        let time_embed_dim = b_channels * 4;
        let conv_cfg = nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;

        let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
        let time_embedding =
            TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;

        let vs_db = vs.pp("down_blocks");
        let down_blocks = (0..n_blocks)
            .map(|i| {
                let BlockConfig {
                    out_channels,
                    use_cross_attn,
                    attention_head_dim,
                } = config.blocks[i];

                // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
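                // `Some(0)` acts as a sentinel: it is replaced by a slice size of half the
                // attention head dimension, any other `Some(n)` is passed through unchanged,
                // and `None` keeps attention slicing disabled.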
                let sliced_attention_size = match config.sliced_attention_size {
                    Some(0) => Some(attention_head_dim / 2),
                    _ => config.sliced_attention_size,
                };

                let in_channels = if i > 0 {
                    config.blocks[i - 1].out_channels
                } else {
                    b_channels
                };
                let db_cfg = DownBlock2DConfig {
                    num_layers: config.layers_per_block,
                    resnet_eps: config.norm_eps,
                    resnet_groups: config.norm_num_groups,
                    add_downsample: i < n_blocks - 1,
                    downsample_padding: config.downsample_padding,
                    ..Default::default()
                };
                if let Some(transformer_layers_per_block) = use_cross_attn {
                    let config = CrossAttnDownBlock2DConfig {
                        downblock: db_cfg,
                        attn_num_head_channels: attention_head_dim,
                        cross_attention_dim: config.cross_attention_dim,
                        sliced_attention_size,
                        use_linear_projection: config.use_linear_projection,
                        transformer_layers_per_block,
                    };
                    let block = CrossAttnDownBlock2D::new(
                        vs_db.pp(i.to_string()),
                        in_channels,
                        out_channels,
                        Some(time_embed_dim),
                        use_flash_attn,
                        config,
                    )?;
                    Ok(UNetDownBlock::CrossAttn(block))
                } else {
                    let block = DownBlock2D::new(
                        vs_db.pp(i.to_string()),
                        in_channels,
                        out_channels,
                        Some(time_embed_dim),
                        db_cfg,
                    )?;
                    Ok(UNetDownBlock::Basic(block))
                }
            })
            .collect::<Result<Vec<_>>>()?;

        // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
        let mid_transformer_layers_per_block = match config.blocks.last() {
            None => 1,
            Some(block) => block.use_cross_attn.unwrap_or(1),
        };
        let mid_cfg = UNetMidBlock2DCrossAttnConfig {
            resnet_eps: config.norm_eps,
            output_scale_factor: config.mid_block_scale_factor,
            cross_attn_dim: config.cross_attention_dim,
            attn_num_head_channels: bl_attention_head_dim,
            resnet_groups: Some(config.norm_num_groups),
            use_linear_projection: config.use_linear_projection,
            transformer_layers_per_block: mid_transformer_layers_per_block,
            ..Default::default()
        };
        let mid_block = UNetMidBlock2DCrossAttn::new(
            vs.pp("mid_block"),
            bl_channels,
            Some(time_embed_dim),
            use_flash_attn,
            mid_cfg,
        )?;
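
        // The up path mirrors the down path in reverse order: each up block consumes the
        // output of the previous block (`prev_out_channels`, starting from the mid block)
        // together with the skip connections stored during the down pass, while
        // `in_channels` tracks the down block one level closer to the input resolution.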
        let vs_ub = vs.pp("up_blocks");
        let up_blocks = (0..n_blocks)
            .map(|i| {
                let BlockConfig {
                    out_channels,
                    use_cross_attn,
                    attention_head_dim,
                } = config.blocks[n_blocks - 1 - i];

                // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
                let sliced_attention_size = match config.sliced_attention_size {
                    Some(0) => Some(attention_head_dim / 2),
                    _ => config.sliced_attention_size,
                };

                let prev_out_channels = if i > 0 {
                    config.blocks[n_blocks - i].out_channels
                } else {
                    bl_channels
                };
                let in_channels = {
                    let index = if i == n_blocks - 1 {
                        0
                    } else {
                        n_blocks - i - 2
                    };
                    config.blocks[index].out_channels
                };
                let ub_cfg = UpBlock2DConfig {
                    num_layers: config.layers_per_block + 1,
                    resnet_eps: config.norm_eps,
                    resnet_groups: config.norm_num_groups,
                    add_upsample: i < n_blocks - 1,
                    ..Default::default()
                };
                if let Some(transformer_layers_per_block) = use_cross_attn {
                    let config = CrossAttnUpBlock2DConfig {
                        upblock: ub_cfg,
                        attn_num_head_channels: attention_head_dim,
                        cross_attention_dim: config.cross_attention_dim,
                        sliced_attention_size,
                        use_linear_projection: config.use_linear_projection,
                        transformer_layers_per_block,
                    };
                    let block = CrossAttnUpBlock2D::new(
                        vs_ub.pp(i.to_string()),
                        in_channels,
                        prev_out_channels,
                        out_channels,
                        Some(time_embed_dim),
                        use_flash_attn,
                        config,
                    )?;
                    Ok(UNetUpBlock::CrossAttn(block))
                } else {
                    let block = UpBlock2D::new(
                        vs_ub.pp(i.to_string()),
                        in_channels,
                        prev_out_channels,
                        out_channels,
                        Some(time_embed_dim),
                        ub_cfg,
                    )?;
                    Ok(UNetUpBlock::Basic(block))
                }
            })
            .collect::<Result<Vec<_>>>()?;

        let conv_norm_out = nn::group_norm(
            config.norm_num_groups,
            b_channels,
            config.norm_eps,
            vs.pp("conv_norm_out"),
        )?;
        let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
        let span = tracing::span!(tracing::Level::TRACE, "unet2d");
        Ok(Self {
            conv_in,
            time_proj,
            time_embedding,
            down_blocks,
            mid_block,
            up_blocks,
            conv_norm_out,
            conv_out,
            span,
            config,
        })
    }

    pub fn forward(
        &self,
        xs: &Tensor,
        timestep: f64,
        encoder_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
    }

    pub fn forward_with_additional_residuals(
        &self,
        xs: &Tensor,
        timestep: f64,
        encoder_hidden_states: &Tensor,
        down_block_additional_residuals: Option<&[Tensor]>,
        mid_block_additional_residual: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bsize, _channels, height, width) = xs.dims4()?;
        let device = xs.device();
        let n_blocks = self.config.blocks.len();
        let num_upsamplers = n_blocks - 1;
        let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
        let forward_upsample_size =
            height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
        // 0. center input if necessary
        let xs = if self.config.center_input_sample {
            ((xs * 2.0)? - 1.0)?
        } else {
            xs.clone()
        };
        // 1. time
        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
        let emb = self.time_proj.forward(&emb)?;
        let emb = self.time_embedding.forward(&emb)?;
        // 2. pre-process
        let xs = self.conv_in.forward(&xs)?;
        // 3. down
        let mut down_block_res_xs = vec![xs.clone()];
        let mut xs = xs;
        for down_block in self.down_blocks.iter() {
            let (_xs, res_xs) = match down_block {
                UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
                UNetDownBlock::CrossAttn(b) => {
                    b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
                }
            };
            down_block_res_xs.extend(res_xs);
            xs = _xs;
        }
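        // `down_block_additional_residuals` are extra per-resolution residuals (e.g. the
        // output of a ControlNet-style conditioning model) that are added to the stored
        // skip connections before they are fed to the up blocks.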
        let new_down_block_res_xs =
            if let Some(down_block_additional_residuals) = down_block_additional_residuals {
                let mut v = vec![];
                // A previous version of this code had a bug because of the addition being made
                // in place via += hence modifying the input of the mid block.
                for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
                    v.push((&down_block_res_xs[i] + residuals)?)
                }
                v
            } else {
                down_block_res_xs
            };
        let mut down_block_res_xs = new_down_block_res_xs;

        // 4. mid
        let xs = self
            .mid_block
            .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
        let xs = match mid_block_additional_residual {
            None => xs,
            Some(m) => (m + xs)?,
        };

        // 5. up
        let mut xs = xs;
        let mut upsample_size = None;
        for (i, up_block) in self.up_blocks.iter().enumerate() {
            let n_resnets = match up_block {
                UNetUpBlock::Basic(b) => b.resnets.len(),
                UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
            };
            let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
            if i < n_blocks - 1 && forward_upsample_size {
                let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
                upsample_size = Some((h, w))
            }
            xs = match up_block {
                UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
                UNetUpBlock::CrossAttn(b) => b.forward(
                    &xs,
                    &res_xs,
                    Some(&emb),
                    upsample_size,
                    Some(encoder_hidden_states),
                )?,
            };
        }

        // 6. post-process
        let xs = self.conv_norm_out.forward(&xs)?;
        let xs = nn::ops::silu(&xs)?;
        self.conv_out.forward(&xs)
    }
}
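
// A minimal usage sketch: build the default (Stable-Diffusion-style) configuration and
// run a single denoising step. The shapes are illustrative assumptions (4 latent
// channels, 64x64 latents, and 77 conditioning tokens whose hidden size matches the
// default `cross_attention_dim` of 1280); weights are zero-initialized through a
// `VarMap` rather than loaded from a checkpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::{DType, Device};

    #[test]
    #[ignore] // instantiating the full UNet with on-the-fly weights is slow
    fn unet_forward_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = nn::VarMap::new();
        let vb = nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let config = UNet2DConditionModelConfig::default();
        let unet = UNet2DConditionModel::new(vb, 4, 4, false, config.clone())?;
        // Latents: (batch, channels, height, width); 64x64 is divisible by the overall
        // upsampling factor of 2^(n_blocks - 1) = 8, so no upsample-size override kicks in.
        let latents = Tensor::zeros((1, 4, 64, 64), DType::F32, &device)?;
        // Conditioning: (batch, seq_len, cross_attention_dim).
        let text_emb = Tensor::zeros((1, 77, config.cross_attention_dim), DType::F32, &device)?;
        let out = unet.forward(&latents, 999., &text_emb)?;
        // The UNet is shape-preserving: the output matches the input latent shape.
        assert_eq!(out.dims4()?, (1, 4, 64, 64));
        Ok(())
    }
}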