summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/resnet.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/resnet.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs129
1 files changed, 129 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
new file mode 100644
index 00000000..b6696083
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -0,0 +1,129 @@
+#![allow(dead_code)]
+//! ResNet Building Blocks
+//!
+//! Some Residual Network blocks used in UNet models.
+//!
+//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
+//! https://arxiv.org/abs/1512.03385
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+/// Configuration for a ResNet block.
+///
+/// Optional fields set to `None` are resolved when the block is built; see the
+/// per-field documentation for the fallback that is applied.
+#[derive(Debug, Clone, Copy)]
+pub struct ResnetBlock2DConfig {
+ /// The number of output channels, defaults to the number of input channels.
+ pub out_channels: Option<usize>,
+ /// The input size of the linear projection applied to the time embedding.
+ /// When `None`, no projection layer is created and any time embedding passed
+ /// to the forward pass is ignored.
+ pub temb_channels: Option<usize>,
+ /// The number of groups to use in group normalization.
+ pub groups: usize,
+ /// The number of groups to use in the second group normalization, defaults
+ /// to `groups`.
+ pub groups_out: Option<usize>,
+ /// The epsilon to be used in the group normalization operations.
+ pub eps: f64,
+ /// Whether to use a 2D convolution in the skip connection. When using None,
+ /// such a convolution is used if the number of input channels is different from
+ /// the number of output channels.
+ pub use_in_shortcut: Option<bool>,
+ // non_linearity: not configurable, the block always applies silu.
+ /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64,
+}
+
+impl Default for ResnetBlock2DConfig {
+    /// Returns the default configuration: 512 time-embedding channels, 32
+    /// normalization groups, an epsilon of `1e-6` and an output scale factor of 1.
+    /// Every other option is left as `None` so that it is derived when the block
+    /// is built.
+    fn default() -> Self {
+        ResnetBlock2DConfig {
+            // Explicit defaults.
+            temb_channels: Some(512),
+            groups: 32,
+            eps: 1e-6,
+            output_scale_factor: 1.0,
+            // Left unset: resolved from the other settings at construction time.
+            out_channels: None,
+            groups_out: None,
+            use_in_shortcut: None,
+        }
+    }
+}
+
+/// A 2D residual block: two (group norm, silu, 3x3 convolution) stages with an
+/// optional time-embedding injection between them and a skip connection that is
+/// added to the result.
+#[derive(Debug)]
+pub struct ResnetBlock2D {
+ /// Group normalization applied to the block input.
+ norm1: nn::GroupNorm,
+ /// First 3x3 convolution, mapping the input channels to the output channels.
+ conv1: nn::Conv2d,
+ /// Group normalization applied before the second convolution.
+ norm2: nn::GroupNorm,
+ /// Second 3x3 convolution, keeping the number of output channels.
+ conv2: nn::Conv2d,
+ /// Linear projection of the time embedding to the output channels, only
+ /// present when `temb_channels` is set in the configuration.
+ time_emb_proj: Option<nn::Linear>,
+ /// 1x1 convolution applied on the skip connection, `None` when the input is
+ /// passed through unchanged.
+ conv_shortcut: Option<nn::Conv2d>,
+ /// The configuration the block was built with; the forward pass reads
+ /// `output_scale_factor` from it.
+ config: ResnetBlock2DConfig,
+}
+
+impl ResnetBlock2D {
+    /// Builds a ResNet block, creating its layers through `vs`.
+    ///
+    /// The layers are created under the sub-paths `norm1`, `conv1`, `norm2`,
+    /// `conv2` and, when enabled, `conv_shortcut` and `time_emb_proj`.
+    ///
+    /// * `vs` - variable builder used to create the layer parameters.
+    /// * `in_channels` - number of channels of the block input.
+    /// * `config` - block configuration, unset options are resolved here:
+    ///   `out_channels` falls back to `in_channels`, `groups_out` to `groups`, and
+    ///   `use_in_shortcut` to `in_channels != out_channels`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error reported while creating one of the layers.
+    pub fn new(
+        vs: nn::VarBuilder,
+        in_channels: usize,
+        config: ResnetBlock2DConfig,
+    ) -> Result<Self> {
+        let out_channels = config.out_channels.unwrap_or(in_channels);
+        // Shared by both 3x3 convolutions: stride 1 with a padding of 1.
+        let conv_cfg = nn::Conv2dConfig {
+            stride: 1,
+            padding: 1,
+        };
+        let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
+        let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+        let groups_out = config.groups_out.unwrap_or(config.groups);
+        let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
+        let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+        // Unless explicitly configured, the skip connection only gets its own
+        // convolution when it has to change the number of channels.
+        let use_in_shortcut = config
+            .use_in_shortcut
+            .unwrap_or(in_channels != out_channels);
+        let conv_shortcut = if use_in_shortcut {
+            // 1x1 convolution, no padding.
+            let conv_cfg = nn::Conv2dConfig {
+                stride: 1,
+                padding: 0,
+            };
+            Some(nn::conv2d(
+                in_channels,
+                out_channels,
+                1,
+                conv_cfg,
+                vs.pp("conv_shortcut"),
+            )?)
+        } else {
+            None
+        };
+        // The projection only exists when the configuration gives the size of
+        // the time embedding.
+        let time_emb_proj = config
+            .temb_channels
+            .map(|temb_channels| nn::linear(temb_channels, out_channels, vs.pp("time_emb_proj")))
+            .transpose()?;
+        Ok(Self {
+            norm1,
+            conv1,
+            norm2,
+            conv2,
+            time_emb_proj,
+            config,
+            conv_shortcut,
+        })
+    }
+
+    /// Applies the block to `xs`.
+    ///
+    /// The input goes through group norm, silu and the first convolution. When
+    /// both `temb` and the time-embedding projection are available, the projected
+    /// embedding is then added. A second group norm, silu and convolution follow.
+    /// The skip connection (the input, optionally passed through the shortcut
+    /// convolution) is added and the sum is divided by `output_scale_factor`.
+    ///
+    /// * `xs` - block input, presumably laid out as (batch, channels, height,
+    ///   width); TODO confirm against callers.
+    /// * `temb` - optional time embedding. It is ignored when the block was built
+    ///   without `temb_channels`.
+    ///
+    /// # Errors
+    ///
+    /// Returns any error raised by the underlying tensor operations.
+    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let shortcut_xs = match &self.conv_shortcut {
+            Some(conv_shortcut) => conv_shortcut.forward(xs)?,
+            None => xs.clone(),
+        };
+        let xs = self.norm1.forward(xs)?;
+        let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
+        let xs = match (temb, &self.time_emb_proj) {
+            // The projected embedding gets two trailing dimensions of size 1, so
+            // it has to be broadcast over the last two dimensions of the feature
+            // map: a plain `add` requires both operands to have the same shape.
+            (Some(temb), Some(time_emb_proj)) => time_emb_proj
+                .forward(&nn::ops::silu(temb)?)?
+                .unsqueeze(D::Minus1)?
+                .unsqueeze(D::Minus1)?
+                .broadcast_add(&xs)?,
+            _ => xs,
+        };
+        let xs = self
+            .conv2
+            .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
+        (shortcut_xs + xs)? / self.config.output_scale_factor
+    }
+}