| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-17 15:59:27 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-17 15:59:27 +0100 |
| commit | 06cc329e715cbb820343f9849a4a45c818cb8c5e (patch) | |
| tree | bbfd12ef3467c88aa85847496a3d305f408eba20 /candle-transformers/src/models/wuerstchen | |
| parent | 5f83c13f17a7b16955c9b649424aca276d5e930d (diff) | |
Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.
* Fixes.
* More fixes (including conv-transpose2d).
* More fixes.
* Again more fixes.
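In short, `WLayerNorm` no longer wraps a `candle_nn::LayerNorm` (and no longer takes a `VarBuilder`); it normalizes over the last dimension by hand with `eps = 1e-6` and no learnable scale or shift. Below is a minimal, self-contained sketch of that computation, distilled from the diff that follows; the free-function name and the tiny demo are illustrative additions, not code from this commit.

```rust
// Illustrative sketch only, not code from this commit. It mirrors the new
// parameter-free WLayerNorm forward pass using candle's public tensor API.
use candle::{DType, Device, Result, Tensor, D};

/// Layer-norm over the last dimension with no learnable scale/shift.
/// Half-precision inputs are upcast to f32 for the reductions, then cast back.
fn parameter_free_layer_norm(xs: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = xs.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = xs.dim(D::Minus1)?;
    let xs = xs.to_dtype(internal_dtype)?;
    // Subtract the mean, then divide by sqrt(variance + eps).
    let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let xs = xs.broadcast_sub(&mean_x)?;
    let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    xs.broadcast_div(&(norm_x + eps)?.sqrt()?)?.to_dtype(x_dtype)
}

fn main() -> Result<()> {
    // Toy input; the real WLayerNorm also permutes NCHW -> NHWC around this step.
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &Device::Cpu)?;
    let ys = parameter_free_layer_norm(&xs, 1e-6)?;
    println!("{ys}");
    Ok(())
}
```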
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
4 files changed, 44 insertions, 44 deletions
```diff
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 10e7b19f..ee318d27 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -1,28 +1,35 @@
-use candle::{Module, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
 // https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
 #[derive(Debug)]
 pub struct WLayerNorm {
-    inner: candle_nn::LayerNorm,
+    eps: f64,
 }
 
 impl WLayerNorm {
-    pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
-        let cfg = candle_nn::layer_norm::LayerNormConfig {
-            eps: 1e-6,
-            remove_mean: true,
-            affine: false,
-        };
-        let inner = candle_nn::layer_norm(size, cfg, vb)?;
-        Ok(Self { inner })
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
     }
 }
 
 impl Module for WLayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.permute((0, 2, 3, 1))?
-            .apply(&self.inner)?
+        let xs = xs.permute((0, 2, 3, 1))?;
+
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?
             .permute((0, 3, 1, 2))
     }
 }
@@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
 
 impl GlobalResponseNorm {
     pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
-        let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
-        let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
+        let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+        let beta = vb.get((1, 1, 1, dim), "beta")?;
         Ok(Self { gamma, beta })
     }
 }
@@ -92,7 +99,7 @@ impl ResBlock {
             ..Default::default()
         };
         let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
         let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
         let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -141,7 +148,7 @@ impl AttnBlock {
         self_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 001b35d7..70e4ba34 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -19,7 +19,7 @@ impl ResBlockStageB {
             ..Default::default()
         };
         let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
         let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
         let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -75,7 +75,7 @@ struct UpBlock {
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: candle_nn::LayerNorm,
+    seq_norm: WLayerNorm,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,
@@ -98,7 +98,7 @@ impl WDiffNeXt {
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
         const BLOCKS: [usize; 4] = [4, 4, 14, 4];
-        const NHEAD: [usize; 4] = [0, 10, 20, 20];
+        const NHEAD: [usize; 4] = [1, 10, 20, 20];
         const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
         const EFFNET_EMBD: usize = 16;
@@ -133,24 +133,21 @@ impl WDiffNeXt {
             };
             effnet_mappers.push(c)
         }
-        let cfg = candle_nn::layer_norm::LayerNormConfig {
-            ..Default::default()
-        };
-        let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
-        let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+        let seq_norm = WLayerNorm::new(c_cond)?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
             c_in * patch_size * patch_size,
-            C_HIDDEN[1],
+            C_HIDDEN[0],
             1,
             Default::default(),
-            vb.pp("embedding.2"),
+            vb.pp("embedding.1"),
         )?;
 
         let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
         for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
             let vb = vb.pp("down_blocks").pp(i);
             let (layer_norm, conv, start_layer_i) = if i > 0 {
-                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                 let cfg = candle_nn::Conv2dConfig {
                     stride: 2,
                     ..Default::default()
@@ -223,7 +220,7 @@ impl WDiffNeXt {
                 sub_blocks.push(sub_block)
             }
             let (layer_norm, conv) = if i > 0 {
-                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                 layer_i += 1;
                 let cfg = candle_nn::Conv2dConfig {
                     stride: 2,
@@ -242,7 +239,7 @@ impl WDiffNeXt {
             up_blocks.push(up_block)
         }
 
-        let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let clf_conv = candle_nn::conv2d(
             C_HIDDEN[0],
             2 * c_out * patch_size * patch_size,
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index a60f8e8a..faf2d2b4 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -1,11 +1,12 @@
+use super::common::WLayerNorm;
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 pub struct MixingResidualBlock {
-    norm1: candle_nn::LayerNorm,
+    norm1: WLayerNorm,
     depthwise_conv: candle_nn::Conv2d,
-    norm2: candle_nn::LayerNorm,
+    norm2: WLayerNorm,
     channelwise_lin1: candle_nn::Linear,
     channelwise_lin2: candle_nn::Linear,
     gammas: Vec<f32>,
@@ -13,13 +14,8 @@ pub struct MixingResidualBlock {
 
 impl MixingResidualBlock {
     pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let cfg = candle_nn::LayerNormConfig {
-            affine: false,
-            eps: 1e-6,
-            remove_mean: true,
-        };
-        let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
-        let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
+        let norm1 = WLayerNorm::new(inp)?;
+        let norm2 = WLayerNorm::new(inp)?;
         let cfg = candle_nn::Conv2dConfig {
             groups: inp,
             ..Default::default()
         };
@@ -120,15 +116,15 @@ impl PaellaVQ {
             d_idx += 1;
             down_blocks.push((conv_block, res_block))
         }
+        let vb_d = vb_d.pp(d_idx);
         let down_blocks_conv = candle_nn::conv2d_no_bias(
             C_LEVELS[1],
             LATENT_CHANNELS,
             1,
             Default::default(),
-            vb_d.pp(d_idx),
+            vb_d.pp(0),
         )?;
-        d_idx += 1;
-        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
+        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
 
         let mut up_blocks = Vec::new();
         let vb_u = vb.pp("up_blocks");
@@ -138,7 +134,7 @@ impl PaellaVQ {
             C_LEVELS[1],
             1,
             Default::default(),
-            vb_u.pp(u_idx),
+            vb_u.pp(u_idx).pp(0),
         )?;
         u_idx += 1;
         for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
@@ -157,7 +153,7 @@ impl PaellaVQ {
                 };
                 let block = candle_nn::conv_transpose2d_no_bias(
                     c_level,
-                    C_LEVELS[i - 1],
+                    C_LEVELS[C_LEVELS.len() - i - 2],
                     4,
                     cfg,
                     vb_u.pp(u_idx),
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index a9e3e793..93385a32 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -33,7 +33,7 @@ impl WPrior {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
         let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
         let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
-        let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
+        let out_ln = super::common::WLayerNorm::new(c)?;
         let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
         let mut blocks = Vec::with_capacity(depth);
         for index in 0..depth {
```