diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 809 |
1 files changed, 809 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs new file mode 100644 index 00000000..8dd6cf26 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -0,0 +1,809 @@ +#![allow(dead_code)] +//! 2D UNet Building Blocks +//! +use crate::attention::{ + AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, +}; +use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; +use candle::{Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug)] +struct Downsample2D { + conv: Option<nn::Conv2d>, + padding: usize, +} + +impl Downsample2D { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + use_conv: bool, + out_channels: usize, + padding: usize, + ) -> Result<Self> { + let conv = if use_conv { + let config = nn::Conv2dConfig { stride: 2, padding }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Some(conv) + } else { + None + }; + Ok(Downsample2D { conv, padding }) + } +} + +impl Downsample2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match &self.conv { + None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), + Some(conv) => { + if self.padding == 0 { + let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?; + conv.forward(&xs) + } else { + conv.forward(xs) + } + } + } + } +} + +// This does not support the conv-transpose mode. +#[derive(Debug)] +struct Upsample2D { + conv: nn::Conv2d, +} + +impl Upsample2D { + fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> { + let config = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl Upsample2D { + fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { + let xs = match size { + None => { + // The following does not work and it's tricky to pass no fixed + // dimensions so hack our way around this. + // xs.upsample_nearest2d(&[], Some(2.), Some(2.) + let (_bsize, _channels, _h, _w) = xs.dims4()?; + crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.)) + } + Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None), + }; + self.conv.forward(&xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownEncoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownEncoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownEncoderBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownEncoderBlock2DConfig, +} + +impl DownEncoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DownEncoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + out_channels: Some(out_channels), + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let downsampler = if config.add_downsample { + let downsample = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsample) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } +} + +impl DownEncoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.downsampler { + Some(downsampler) => downsampler.forward(&xs), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpDecoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpDecoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpDecoderBlock2D { + resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpDecoderBlock2DConfig, +} + +impl UpDecoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UpDecoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let upsampler = if config.add_upsample { + let upsample = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsample) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } +} + +impl UpDecoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, None), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: Option<usize>, + // attention_type "default" + pub output_scale_factor: f64, +} + +impl Default for UNetMidBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: Some(1), + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2D { + resnet: ResnetBlock2D, + attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>, + pub config: UNetMidBlock2DConfig, +} + +impl UNetMidBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let attn_cfg = AttentionBlockConfig { + num_head_channels: config.attn_num_head_channels, + num_groups: resnet_groups, + rescale_output_factor: config.output_scale_factor, + eps: config.resnet_eps, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs)?, temb)? + } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DCrossAttnConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: usize, + // attention_type "default" + pub output_scale_factor: f64, + pub cross_attn_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNetMidBlock2DCrossAttnConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: 1, + output_scale_factor: 1., + cross_attn_dim: 1280, + sliced_attention_size: None, // Sliced attention disabled + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2DCrossAttn { + resnet: ResnetBlock2D, + attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>, + pub config: UNetMidBlock2DCrossAttnConfig, +} + +impl UNetMidBlock2DCrossAttn { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DCrossAttnConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let n_heads = config.attn_num_head_channels; + let attn_cfg = SpatialTransformerConfig { + depth: 1, + num_groups: resnet_groups, + context_dim: Some(config.cross_attn_dim), + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = SpatialTransformer::new( + vs_attns.pp(&index.to_string()), + in_channels, + n_heads, + in_channels / n_heads, + attn_cfg, + )?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)? + } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownBlock2DConfig, +} + +impl DownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: DownBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let downsampler = if config.add_downsample { + let downsampler = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsampler) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> { + let mut xs = xs.clone(); + let mut output_states = vec![]; + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, temb)?; + output_states.push(xs.clone()); + } + let xs = match &self.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnDownBlock2DConfig { + pub downblock: DownBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnDownBlock2DConfig { + fn default() -> Self { + Self { + downblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnDownBlock2D { + downblock: DownBlock2D, + attentions: Vec<SpatialTransformer>, + pub config: CrossAttnDownBlock2DConfig, +} + +impl CrossAttnDownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: CrossAttnDownBlock2DConfig, + ) -> Result<Self> { + let downblock = DownBlock2D::new( + vs.clone(), + in_channels, + out_channels, + temb_channels, + config.downblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: 1, + context_dim: Some(config.cross_attention_dim), + num_groups: config.downblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.downblock.num_layers) + .map(|i| { + SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + downblock, + attentions, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Vec<Tensor>)> { + let mut output_states = vec![]; + let mut xs = xs.clone(); + for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) { + xs = resnet.forward(&xs, temb)?; + xs = attn.forward(&xs, encoder_hidden_states)?; + output_states.push(xs.clone()); + } + let xs = match &self.downblock.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpBlock2D { + pub resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpBlock2DConfig, +} + +impl UpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: UpBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + temb_channels, + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let res_skip_channels = if i == config.num_layers - 1 { + in_channels + } else { + out_channels + }; + let resnet_in_channels = if i == 0 { + prev_output_channels + } else { + out_channels + }; + let in_channels = resnet_in_channels + res_skip_channels; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let upsampler = if config.add_upsample { + let upsampler = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsampler) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + ) -> Result<Tensor> { + let mut xs = xs.clone(); + for (index, resnet) in self.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = resnet.forward(&xs, temb)?; + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnUpBlock2DConfig { + pub upblock: UpBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnUpBlock2DConfig { + fn default() -> Self { + Self { + upblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnUpBlock2D { + pub upblock: UpBlock2D, + pub attentions: Vec<SpatialTransformer>, + pub config: CrossAttnUpBlock2DConfig, +} + +impl CrossAttnUpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: CrossAttnUpBlock2DConfig, + ) -> Result<Self> { + let upblock = UpBlock2D::new( + vs.clone(), + in_channels, + prev_output_channels, + out_channels, + temb_channels, + config.upblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: 1, + context_dim: Some(config.cross_attention_dim), + num_groups: config.upblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.upblock.num_layers) + .map(|i| { + SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + upblock, + attentions, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let mut xs = xs.clone(); + for (index, resnet) in self.upblock.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = resnet.forward(&xs, temb)?; + xs = self.attentions[index].forward(&xs, encoder_hidden_states)?; + } + match &self.upblock.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} |