Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs  809
1 file changed, 809 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
new file mode 100644
index 00000000..8dd6cf26
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -0,0 +1,809 @@
+#![allow(dead_code)]
+//! 2D UNet Building Blocks
+//!
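+//! Building blocks used by the stable-diffusion example: down- and up-sampling
+//! layers, encoder/decoder blocks, mid blocks with optional cross-attention,
+//! and the down and up blocks that the UNet itself stacks.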
+use crate::attention::{
+ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+};
+use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct Downsample2D {
+ conv: Option<nn::Conv2d>,
+ padding: usize,
+}
+
+impl Downsample2D {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ use_conv: bool,
+ out_channels: usize,
+ padding: usize,
+ ) -> Result<Self> {
+ let conv = if use_conv {
+ let config = nn::Conv2dConfig { stride: 2, padding };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Some(conv)
+ } else {
+ None
+ };
+ Ok(Downsample2D { conv, padding })
+ }
+}
+
+impl Downsample2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match &self.conv {
+ None => crate::utils::avg_pool2d(xs), // 2x2 average pooling with stride 2.
+ Some(conv) => {
+ if self.padding == 0 {
+ let xs = crate::utils::pad(xs)?; // Zero-pad by one on the right and bottom.
+ conv.forward(&xs)
+ } else {
+ conv.forward(xs)
+ }
+ }
+ }
+ }
+}
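+
+// Both paths of `Downsample2D` halve the spatial resolution: the conv path is a
+// stride-2 3x3 convolution (zero-padded by one on the right and bottom when
+// `padding == 0`), the pooling path a 2x2 average pool.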
+
+// This does not support the conv-transpose mode.
+#[derive(Debug)]
+struct Upsample2D {
+ conv: nn::Conv2d,
+}
+
+impl Upsample2D {
+ fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
+ let config = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl Upsample2D {
+ fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+ let xs = match size {
+ None => {
+ // Passing no fixed output size is not supported yet and is tricky to
+ // express, so we hack around it with a helper that doubles both spatial
+ // dimensions, standing in for:
+ // xs.upsample_nearest2d(&[2 * h, 2 * w], Some(2.), Some(2.))
+ let (_bsize, _channels, _h, _w) = xs.dims4()?;
+ crate::utils::upsample_nearest2d(xs)?
+ }
+ // The requested size is not forwarded to the helper yet; it upsamples by 2.
+ Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?,
+ };
+ self.conv.forward(&xs)
+ }
+}
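+
+// `Upsample2D` doubles the spatial resolution with a nearest-neighbor upsample
+// and then applies a padded 3x3 convolution, so the channel count can change
+// while the upsampled height and width are preserved.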
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownEncoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownEncoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownEncoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownEncoderBlock2DConfig,
+}
+
+impl DownEncoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DownEncoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ out_channels: Some(out_channels),
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let downsampler = if config.add_downsample {
+ let downsample = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+}
+
+impl DownEncoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.downsampler {
+ Some(downsampler) => downsampler.forward(&xs),
+ None => Ok(xs),
+ }
+ }
+}
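+
+// Hedged usage sketch with illustrative channel sizes and weight prefix (these
+// particular values are not taken from the example itself):
+//
+//     let cfg = DownEncoderBlock2DConfig { num_layers: 2, ..Default::default() };
+//     let block = DownEncoderBlock2D::new(vs.pp("down_blocks").pp("0"), 128, 256, cfg)?;
+//     let ys = block.forward(&xs)?; // (B, 128, H, W) -> (B, 256, H / 2, W / 2)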
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpDecoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpDecoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpDecoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpDecoderBlock2DConfig,
+}
+
+impl UpDecoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UpDecoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let upsampler = if config.add_upsample {
+ let upsample =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+}
+
+impl UpDecoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, None),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: Option<usize>,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+}
+
+impl Default for UNetMidBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: Some(1),
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2D {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DConfig,
+}
+
+impl UNetMidBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let attn_cfg = AttentionBlockConfig {
+ num_head_channels: config.attn_num_head_channels,
+ num_groups: resnet_groups,
+ rescale_output_factor: config.output_scale_factor,
+ eps: config.resnet_eps,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs)?, temb)?
+ }
+ Ok(xs)
+ }
+}
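+
+// `UNetMidBlock2D` keeps the resolution fixed: one resnet followed by
+// `num_layers` (self-attention, resnet) pairs, with the optional time-step
+// embedding `temb` threaded through every resnet.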
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: usize,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+ pub cross_attn_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNetMidBlock2DCrossAttnConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: 1,
+ output_scale_factor: 1.,
+ cross_attn_dim: 1280,
+ sliced_attention_size: None, // Sliced attention disabled
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2DCrossAttn {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DCrossAttnConfig,
+}
+
+impl UNetMidBlock2DCrossAttn {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DCrossAttnConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let n_heads = config.attn_num_head_channels;
+ let attn_cfg = SpatialTransformerConfig {
+ depth: 1,
+ num_groups: resnet_groups,
+ context_dim: Some(config.cross_attn_dim),
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = SpatialTransformer::new(
+ vs_attns.pp(&index.to_string()),
+ in_channels,
+ n_heads,
+ in_channels / n_heads,
+ attn_cfg,
+ )?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
+ }
+ Ok(xs)
+ }
+}
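+
+// Same layout as `UNetMidBlock2D`, but each attention step is a
+// `SpatialTransformer` that cross-attends to `encoder_hidden_states` (in the
+// stable-diffusion pipeline, the text-encoder output).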
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownBlock2DConfig,
+}
+
+impl DownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: DownBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let downsampler = if config.add_downsample {
+ let downsampler = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut xs = xs.clone();
+ let mut output_states = vec![];
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, temb)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
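+
+// The `output_states` returned by `DownBlock2D::forward` record the activation
+// after every resnet (and after the optional downsampler); the matching up
+// block later concatenates them back in, most recent first, as skip connections.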
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnDownBlock2DConfig {
+ pub downblock: DownBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnDownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ downblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnDownBlock2D {
+ downblock: DownBlock2D,
+ attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnDownBlock2DConfig,
+}
+
+impl CrossAttnDownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnDownBlock2DConfig,
+ ) -> Result<Self> {
+ let downblock = DownBlock2D::new(
+ vs.clone(),
+ in_channels,
+ out_channels,
+ temb_channels,
+ config.downblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.downblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.downblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ downblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut output_states = vec![];
+ let mut xs = xs.clone();
+ for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+ xs = resnet.forward(&xs, temb)?;
+ xs = attn.forward(&xs, encoder_hidden_states)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downblock.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
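+
+// Same as `DownBlock2D::forward`, except each resnet output also goes through a
+// `SpatialTransformer` before being recorded as a skip connection.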
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpBlock2D {
+ pub resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpBlock2DConfig,
+}
+
+impl UpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: UpBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ temb_channels,
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let res_skip_channels = if i == config.num_layers - 1 {
+ in_channels
+ } else {
+ out_channels
+ };
+ let resnet_in_channels = if i == 0 {
+ prev_output_channels
+ } else {
+ out_channels
+ };
+ let in_channels = resnet_in_channels + res_skip_channels;
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let upsampler = if config.add_upsample {
+ let upsampler =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
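+
+// Channel bookkeeping for the up path: resnet `i` takes
+// `resnet_in_channels + res_skip_channels` input channels because the skip
+// tensor from the down path is concatenated along dim 1, with `res_xs` consumed
+// from the back (the most recently pushed skip is used first).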
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnUpBlock2DConfig {
+ pub upblock: UpBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnUpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ upblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+ pub upblock: UpBlock2D,
+ pub attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnUpBlock2DConfig,
+ ) -> Result<Self> {
+ let upblock = UpBlock2D::new(
+ vs.clone(),
+ in_channels,
+ prev_output_channels,
+ out_channels,
+ temb_channels,
+ config.upblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.upblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.upblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ upblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+ }
+ match &self.upblock.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
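+
+// Hedged forward-pass sketch for a cross-attention up block; `skips` and
+// `text_emb` are illustrative bindings, not names used in this file:
+//
+//     let xs = up_block.forward(&xs, &skips, Some(&temb), None, Some(&text_emb))?;
+//
+// `skips` holds the states produced by the matching down block and `text_emb`
+// the encoder hidden states used for cross-attention.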