diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/resnet.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/resnet.rs | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs new file mode 100644 index 00000000..b6696083 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -0,0 +1,129 @@ +#![allow(dead_code)] +//! ResNet Building Blocks +//! +//! Some Residual Network blocks used in UNet models. +//! +//! Denoising Diffusion Implicit Models, K. He and al, 2015. +//! https://arxiv.org/abs/1512.03385 +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +/// Configuration for a ResNet block. +#[derive(Debug, Clone, Copy)] +pub struct ResnetBlock2DConfig { + /// The number of output channels, defaults to the number of input channels. + pub out_channels: Option<usize>, + pub temb_channels: Option<usize>, + /// The number of groups to use in group normalization. + pub groups: usize, + pub groups_out: Option<usize>, + /// The epsilon to be used in the group normalization operations. + pub eps: f64, + /// Whether to use a 2D convolution in the skip connection. When using None, + /// such a convolution is used if the number of input channels is different from + /// the number of output channels. + pub use_in_shortcut: Option<bool>, + // non_linearity: silu + /// The final output is scaled by dividing by this value. + pub output_scale_factor: f64, +} + +impl Default for ResnetBlock2DConfig { + fn default() -> Self { + Self { + out_channels: None, + temb_channels: Some(512), + groups: 32, + groups_out: None, + eps: 1e-6, + use_in_shortcut: None, + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct ResnetBlock2D { + norm1: nn::GroupNorm, + conv1: nn::Conv2d, + norm2: nn::GroupNorm, + conv2: nn::Conv2d, + time_emb_proj: Option<nn::Linear>, + conv_shortcut: Option<nn::Conv2d>, + config: ResnetBlock2DConfig, +} + +impl ResnetBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + config: ResnetBlock2DConfig, + ) -> Result<Self> { + let out_channels = config.out_channels.unwrap_or(in_channels); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; + let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; + let groups_out = config.groups_out.unwrap_or(config.groups); + let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?; + let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?; + let use_in_shortcut = config + .use_in_shortcut + .unwrap_or(in_channels != out_channels); + let conv_shortcut = if use_in_shortcut { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 0, + }; + Some(nn::conv2d( + in_channels, + out_channels, + 1, + conv_cfg, + vs.pp("conv_shortcut"), + )?) + } else { + None + }; + let time_emb_proj = match config.temb_channels { + None => None, + Some(temb_channels) => Some(nn::linear( + temb_channels, + out_channels, + vs.pp("time_emb_proj"), + )?), + }; + Ok(Self { + norm1, + conv1, + norm2, + conv2, + time_emb_proj, + config, + conv_shortcut, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let shortcut_xs = match &self.conv_shortcut { + Some(conv_shortcut) => conv_shortcut.forward(xs)?, + None => xs.clone(), + }; + let xs = self.norm1.forward(xs)?; + let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?; + let xs = match (temb, &self.time_emb_proj) { + (Some(temb), Some(time_emb_proj)) => time_emb_proj + .forward(&nn::ops::silu(temb)?)? + .unsqueeze(D::Minus1)? + .unsqueeze(D::Minus1)? + .add(&xs)?, + _ => xs, + }; + let xs = self + .conv2 + .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?; + (shortcut_xs + xs)? / self.config.output_scale_factor + } +} |