diff options
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs new file mode 100644 index 00000000..82c973a1 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -0,0 +1,55 @@ +#![allow(unused)] +use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm}; +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct ResBlockStageB { + depthwise: candle_nn::Conv2d, + norm: WLayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_grn: GlobalResponseNorm, + channelwise_lin2: candle_nn::Linear, +} + +impl ResBlockStageB { + pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + groups: c, + padding: ksize / 2, + ..Default::default() + }; + let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?; + let norm = WLayerNorm::new(c, vb.pp("norm"))?; + let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?; + let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?; + let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; + Ok(Self { + depthwise, + norm, + channelwise_lin1, + channelwise_grn, + channelwise_lin2, + }) + } + + pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> { + let x_res = xs; + let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?; + let xs = match x_skip { + None => xs.clone(), + Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?, + }; + let xs = xs + .permute((0, 2, 3, 1))? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_grn)? + .apply(&self.channelwise_lin2)? + .permute((0, 3, 1, 2))?; + xs + x_res + } +} + +#[derive(Debug)] +pub struct WDiffNeXt {} |