Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/unet_2d.rs')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d.rs  401
1 file changed, 401 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
new file mode 100644
index 00000000..a3ed136e
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -0,0 +1,401 @@
+//! 2D UNet Denoising Models
+//!
+//! The 2D UNet models take a noisy sample and the current diffusion timestep
+//! as input and return a denoised version of the sample.
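+//!
+//! A minimal usage sketch (hypothetical bindings: `vb` is assumed to be a
+//! `VarBuilder` over pretrained weights, `latents` a noisy latent batch, and
+//! `text_emb` the text encoder output used for cross-attention):
+//! ```ignore
+//! let cfg = UNet2DConditionModelConfig::default();
+//! let unet = UNet2DConditionModel::new(vb, 4, 4, false, cfg)?;
+//! let noise_pred = unet.forward(&latents, 999., &text_emb)?;
+//! ```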
+use super::embeddings::{TimestepEmbedding, Timesteps};
+use super::unet_2d_blocks::*;
+use super::utils::{conv2d, Conv2d};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug, Clone, Copy)]
+pub struct BlockConfig {
+ pub out_channels: usize,
+ /// When `None`, no cross-attention is used; when `Some(d)`, cross-attention
+ /// is used and `d` is the number of transformer layers per block.
+ pub use_cross_attn: Option<usize>,
+ pub attention_head_dim: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct UNet2DConditionModelConfig {
+ pub center_input_sample: bool,
+ pub flip_sin_to_cos: bool,
+ pub freq_shift: f64,
+ pub blocks: Vec<BlockConfig>,
+ pub layers_per_block: usize,
+ pub downsample_padding: usize,
+ pub mid_block_scale_factor: f64,
+ pub norm_num_groups: usize,
+ pub norm_eps: f64,
+ pub cross_attention_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
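+// These defaults mirror the UNet configuration of Stable Diffusion v1.x:
+// four blocks at 320/640/1280/1280 channels, with cross-attention everywhere
+// except the last (deepest) block.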
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ blocks: vec![
+ BlockConfig {
+ out_channels: 320,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 640,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: None,
+ attention_head_dim: 8,
+ },
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+#[derive(Debug)]
+enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+#[derive(Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: Conv2d,
+ time_proj: Timesteps,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec<UNetDownBlock>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec<UNetUpBlock>,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: Conv2d,
+ span: tracing::Span,
+ config: UNet2DConditionModelConfig,
+}
+
+impl UNet2DConditionModel {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ use_flash_attn: bool,
+ config: UNet2DConditionModelConfig,
+ ) -> Result<Self> {
+ let n_blocks = config.blocks.len();
+ let b_channels = config.blocks[0].out_channels;
+ let bl_channels = config.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
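+ // The timestep embedding is four times the width of the first block
+ // (1280 for the default 320-channel configuration).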
+ let time_embed_dim = b_channels * 4;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+
+ let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
+ let time_embedding =
+ TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
+
+ let vs_db = vs.pp("down_blocks");
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[i];
+
+ // Enable automatic attention slicing when the config's sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
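+ // Each down block consumes the previous block's output channels;
+ // the first one takes the conv_in output.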
+ let in_channels = if i > 0 {
+ config.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+ let db_cfg = DownBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: i < n_blocks - 1,
+ downsample_padding: config.downsample_padding,
+ ..Default::default()
+ };
+ if let Some(transformer_layers_per_block) = use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig {
+ downblock: db_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
+ };
+ let block = CrossAttnDownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ config,
+ )?;
+ Ok(UNetDownBlock::CrossAttn(block))
+ } else {
+ let block = DownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ db_cfg,
+ )?;
+ Ok(UNetDownBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+ let mid_transformer_layers_per_block = match config.blocks.last() {
+ None => 1,
+ Some(block) => block.use_cross_attn.unwrap_or(1),
+ };
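+ // The mid block runs at the deepest resolution, with the last block's
+ // channel count and cross-attention.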
+ let mid_cfg = UNetMidBlock2DCrossAttnConfig {
+ resnet_eps: config.norm_eps,
+ output_scale_factor: config.mid_block_scale_factor,
+ cross_attn_dim: config.cross_attention_dim,
+ attn_num_head_channels: bl_attention_head_dim,
+ resnet_groups: Some(config.norm_num_groups),
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block: mid_transformer_layers_per_block,
+ ..Default::default()
+ };
+
+ let mid_block = UNetMidBlock2DCrossAttn::new(
+ vs.pp("mid_block"),
+ bl_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ mid_cfg,
+ )?;
+
+ let vs_ub = vs.pp("up_blocks");
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[n_blocks - 1 - i];
+
+ // Enable automatic attention slicing when the config's sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
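+ // Up blocks mirror the down blocks in reverse: prev_out_channels is
+ // what flows in from below (the mid block when i == 0), while
+ // in_channels follows the skip connections recorded on the way down.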
+ let prev_out_channels = if i > 0 {
+ config.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+ let in_channels = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ config.blocks[index].out_channels
+ };
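+ // One extra resnet per up block: the down pass stores
+ // layers_per_block residuals per level plus the downsampled output,
+ // and each resnet here consumes one skip connection.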
+ let ub_cfg = UpBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: i < n_blocks - 1,
+ ..Default::default()
+ };
+ if let Some(transformer_layers_per_block) = use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig {
+ upblock: ub_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
+ };
+ let block = CrossAttnUpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ config,
+ )?;
+ Ok(UNetUpBlock::CrossAttn(block))
+ } else {
+ let block = UpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ ub_cfg,
+ )?;
+ Ok(UNetUpBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ b_channels,
+ config.norm_eps,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "unet2d");
+ Ok(Self {
+ conv_in,
+ time_proj,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (bsize, _channels, height, width) = xs.dims4()?;
+ let device = xs.device();
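+ // If the input height or width is not divisible by 2^(n_blocks - 1), the
+ // up blocks must be given explicit target sizes taken from the skips.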
+ let n_blocks = self.config.blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+ // 0. center input if necessary
+ let xs = if self.config.center_input_sample {
+ ((xs * 2.0)? - 1.0)?
+ } else {
+ xs.clone()
+ };
+ // 1. time
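+ // Broadcast the scalar timestep across the batch, then apply the
+ // sinusoidal projection followed by the learned embedding MLP.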
+ let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
+ let emb = self.time_proj.forward(&emb)?;
+ let emb = self.time_embedding.forward(&emb)?;
+ // 2. pre-process
+ let xs = self.conv_in.forward(&xs)?;
+ // 3. down
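+ // Stash every intermediate activation; the up path pops them back off as
+ // skip connections (conv_in output first, deepest features last).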
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = _xs;
+ }
+
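+ // Additional residuals (e.g. coming from a ControlNet) are added to the
+ // stored skip connections before the mid block runs.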
+ let new_down_block_res_xs =
+ if let Some(down_block_additional_residuals) = down_block_additional_residuals {
+ let mut v = vec![];
+ // Note: a previous version of this code had a bug because the addition
+ // was made in place via `+=`, which modified the input of the mid block.
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
+ v.push((&down_block_res_xs[i] + residuals)?)
+ }
+ v
+ } else {
+ down_block_res_xs
+ };
+ let mut down_block_res_xs = new_down_block_res_xs;
+
+ // 4. mid
+ let xs = self
+ .mid_block
+ .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
+ let xs = match mid_block_additional_residual {
+ None => xs,
+ Some(m) => (m + xs)?,
+ };
+ // 5. up
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
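+ // Pop the n_resnets most recently stored skip connections for this block.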
+ let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
+ if i < n_blocks - 1 && forward_upsample_size {
+ let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
+ upsample_size = Some((h, w))
+ }
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ &xs,
+ &res_xs,
+ Some(&emb),
+ upsample_size,
+ Some(encoder_hidden_states),
+ )?,
+ };
+ }
+ // 6. post-process
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}