author    Laurent Mazare <laurent.mazare@gmail.com>  2023-09-17 15:59:27 +0100
committer GitHub <noreply@github.com>  2023-09-17 15:59:27 +0100
commit    06cc329e715cbb820343f9849a4a45c818cb8c5e (patch)
tree      bbfd12ef3467c88aa85847496a3d305f408eba20 /candle-transformers/src/models/wuerstchen
parent    5f83c13f17a7b16955c9b649424aca276d5e930d (diff)
Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.
* Fixes.
* More fixes (including conv-transpose2d).
* More fixes.
* Again more fixes.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r--  candle-transformers/src/models/wuerstchen/common.rs     39
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs   23
-rw-r--r--  candle-transformers/src/models/wuerstchen/paella_vq.rs  24
-rw-r--r--  candle-transformers/src/models/wuerstchen/prior.rs       2
4 files changed, 44 insertions, 44 deletions
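
The core of the change, for reference before the diff itself: WLayerNorm now holds no tensors at all and normalizes the channel axis directly in its forward pass. A minimal standalone sketch of the same computation, assuming only the candle crate (the free function name is illustrative; the real code lives in the WLayerNorm Module impl below):

use candle::{DType, Device, Result, Tensor, D};

// Parameter-free layer norm over the channels of an NCHW tensor, mirroring
// the forward pass introduced by this commit.
fn w_layer_norm(xs: &Tensor, eps: f64) -> Result<Tensor> {
    // Move channels last: (B, C, H, W) -> (B, H, W, C).
    let xs = xs.permute((0, 2, 3, 1))?;
    let x_dtype = xs.dtype();
    // Accumulate in f32 when the input is half precision.
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = xs.dim(D::Minus1)?;
    let xs = xs.to_dtype(internal_dtype)?;
    let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let xs = xs.broadcast_sub(&mean_x)?;
    let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    xs.broadcast_div(&(norm_x + eps)?.sqrt()?)?
        .to_dtype(x_dtype)?
        .permute((0, 3, 1, 2))
}

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 4, 4), &Device::Cpu)?;
    let ys = w_layer_norm(&xs, 1e-6)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}
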
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 10e7b19f..ee318d27 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -1,28 +1,35 @@
-use candle::{Module, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
#[derive(Debug)]
pub struct WLayerNorm {
- inner: candle_nn::LayerNorm,
+ eps: f64,
}
impl WLayerNorm {
- pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
- let cfg = candle_nn::layer_norm::LayerNormConfig {
- eps: 1e-6,
- remove_mean: true,
- affine: false,
- };
- let inner = candle_nn::layer_norm(size, cfg, vb)?;
- Ok(Self { inner })
+ pub fn new(_size: usize) -> Result<Self> {
+ Ok(Self { eps: 1e-6 })
}
}
impl Module for WLayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- xs.permute((0, 2, 3, 1))?
- .apply(&self.inner)?
+ let xs = xs.permute((0, 2, 3, 1))?;
+
+ let x_dtype = xs.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+
+ let hidden_size = xs.dim(D::Minus1)?;
+ let xs = xs.to_dtype(internal_dtype)?;
+ let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let xs = xs.broadcast_sub(&mean_x)?;
+ let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)?
.permute((0, 3, 1, 2))
}
}
@@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
impl GlobalResponseNorm {
pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
- let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
- let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
+ let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+ let beta = vb.get((1, 1, 1, dim), "beta")?;
Ok(Self { gamma, beta })
}
}
@@ -92,7 +99,7 @@ impl ResBlock {
..Default::default()
};
let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
- let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -141,7 +148,7 @@ impl AttnBlock {
self_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
- let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let norm = WLayerNorm::new(c)?;
let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
Ok(Self {
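
On the GlobalResponseNorm fix above: gamma and beta must be (1, 1, 1, dim) so they broadcast over NHWC activations, and the stray fifth dimension broke that broadcast. A hedged sketch of a GRN forward under that layout (function and variable names are illustrative assumptions, not taken from common.rs):

use candle::{Result, Tensor, D};

// Global response norm over NHWC input; gamma/beta are (1, 1, 1, C) and
// broadcast over the batch and spatial dims, which the extra leading
// dimension removed by this commit prevented.
fn global_response_norm(xs: &Tensor, gamma: &Tensor, beta: &Tensor) -> Result<Tensor> {
    // L2 norm over the spatial dims (H, W), kept for broadcasting.
    let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    // Normalize by the mean norm across channels.
    let stand_div_norm = agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
    // Scale, shift, and add the residual.
    xs.broadcast_mul(&stand_div_norm)?
        .broadcast_mul(gamma)?
        .broadcast_add(beta)?
        .broadcast_add(xs)
}
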
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 001b35d7..70e4ba34 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -19,7 +19,7 @@ impl ResBlockStageB {
..Default::default()
};
let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
- let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -75,7 +75,7 @@ struct UpBlock {
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
- seq_norm: candle_nn::LayerNorm,
+ seq_norm: WLayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
@@ -98,7 +98,7 @@ impl WDiffNeXt {
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
- const NHEAD: [usize; 4] = [0, 10, 20, 20];
+ const NHEAD: [usize; 4] = [1, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
const EFFNET_EMBD: usize = 16;
@@ -133,24 +133,21 @@ impl WDiffNeXt {
};
effnet_mappers.push(c)
}
- let cfg = candle_nn::layer_norm::LayerNormConfig {
- ..Default::default()
- };
- let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
- let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+ let seq_norm = WLayerNorm::new(c_cond)?;
+ let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,
- C_HIDDEN[1],
+ C_HIDDEN[0],
1,
Default::default(),
- vb.pp("embedding.2"),
+ vb.pp("embedding.1"),
)?;
let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
let vb = vb.pp("down_blocks").pp(i);
let (layer_norm, conv, start_layer_i) = if i > 0 {
- let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
@@ -223,7 +220,7 @@ impl WDiffNeXt {
sub_blocks.push(sub_block)
}
let (layer_norm, conv) = if i > 0 {
- let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
layer_i += 1;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
@@ -242,7 +239,7 @@ impl WDiffNeXt {
up_blocks.push(up_block)
}
- let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+ let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
let clf_conv = candle_nn::conv2d(
C_HIDDEN[0],
2 * c_out * patch_size * patch_size,
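
A note on the NHEAD fix above: Attention::new derives the per-head width as c / nhead, so a head count of 0 is never a usable value, presumably why NHEAD[0] moved from 0 to 1. In sketch form (illustrative helper, not the actual constructor):

// The per-head width comes from integer division, so nhead == 0 would
// divide by zero as soon as the expression is evaluated.
fn head_dim(c: usize, nhead: usize) -> usize {
    assert!(nhead > 0, "head count must be non-zero");
    c / nhead
}

fn main() {
    const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
    const NHEAD: [usize; 4] = [1, 10, 20, 20]; // the fixed values
    for (&c, &n) in C_HIDDEN.iter().zip(NHEAD.iter()) {
        println!("c = {c}, nhead = {n}, head_dim = {}", head_dim(c, n));
    }
}
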
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index a60f8e8a..faf2d2b4 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -1,11 +1,12 @@
+use super::common::WLayerNorm;
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug)]
pub struct MixingResidualBlock {
- norm1: candle_nn::LayerNorm,
+ norm1: WLayerNorm,
depthwise_conv: candle_nn::Conv2d,
- norm2: candle_nn::LayerNorm,
+ norm2: WLayerNorm,
channelwise_lin1: candle_nn::Linear,
channelwise_lin2: candle_nn::Linear,
gammas: Vec<f32>,
@@ -13,13 +14,8 @@ pub struct MixingResidualBlock {
impl MixingResidualBlock {
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
- let cfg = candle_nn::LayerNormConfig {
- affine: false,
- eps: 1e-6,
- remove_mean: true,
- };
- let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
- let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
+ let norm1 = WLayerNorm::new(inp)?;
+ let norm2 = WLayerNorm::new(inp)?;
let cfg = candle_nn::Conv2dConfig {
groups: inp,
..Default::default()
@@ -120,15 +116,15 @@ impl PaellaVQ {
d_idx += 1;
down_blocks.push((conv_block, res_block))
}
+ let vb_d = vb_d.pp(d_idx);
let down_blocks_conv = candle_nn::conv2d_no_bias(
C_LEVELS[1],
LATENT_CHANNELS,
1,
Default::default(),
- vb_d.pp(d_idx),
+ vb_d.pp(0),
)?;
- d_idx += 1;
- let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
+ let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
let mut up_blocks = Vec::new();
let vb_u = vb.pp("up_blocks");
@@ -138,7 +134,7 @@ impl PaellaVQ {
C_LEVELS[1],
1,
Default::default(),
- vb_u.pp(u_idx),
+ vb_u.pp(u_idx).pp(0),
)?;
u_idx += 1;
for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
@@ -157,7 +153,7 @@ impl PaellaVQ {
};
let block = candle_nn::conv_transpose2d_no_bias(
c_level,
- C_LEVELS[i - 1],
+ C_LEVELS[C_LEVELS.len() - i - 2],
4,
cfg,
vb_u.pp(u_idx),
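
The conv-transpose fix above is index arithmetic: the loop walks C_LEVELS in reverse, so the old C_LEVELS[i - 1] indexed the unreversed array and underflows at i = 0, while C_LEVELS[C_LEVELS.len() - i - 2] selects the next lower level. A sketch with assumed channel counts:

// Walking the levels in reverse, each upsampling conv-transpose must map
// the current level down to the next lower one: original index len - i - 2.
fn main() {
    const C_LEVELS: [usize; 2] = [320, 640]; // assumed values for illustration
    for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
        if i + 1 < C_LEVELS.len() {
            let c_out = C_LEVELS[C_LEVELS.len() - i - 2];
            println!("up block {i}: {c_level} -> {c_out} channels");
        }
    }
}
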
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index a9e3e793..93385a32 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -33,7 +33,7 @@ impl WPrior {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
- let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
+ let out_ln = super::common::WLayerNorm::new(c)?;
let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
let mut blocks = Vec::with_capacity(depth);
for index in 0..depth {
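
For code that constructs these layers directly, the only migration is dropping the VarBuilder argument, since there are no longer weights to load. A hypothetical call site, assuming the wuerstchen::common module is publicly reachable as laid out in the repo:

use candle::{Device, Module, Result, Tensor};
use candle_transformers::models::wuerstchen::common::WLayerNorm;

fn main() -> Result<()> {
    let c = 64;
    // Before this commit: WLayerNorm::new(c, vb.pp("out.0"))?
    let out_ln = WLayerNorm::new(c)?;
    let xs = Tensor::randn(0f32, 1f32, (1, c, 8, 8), &Device::Cpu)?;
    let ys = out_ln.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, c, 8, 8]);
    Ok(())
}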