author    Laurent Mazare <laurent.mazare@gmail.com>  2023-09-14 22:16:31 +0200
committer GitHub <noreply@github.com>  2023-09-14 22:16:31 +0200
commit    91ec546febee4c6333cd65d95e8fd09e94499024 (patch)
tree      d0f9422bb447018412744243076c0ade935fcb22 /candle-transformers/src/models/wuerstchen
parent    0a647875ec7f9861ede7fa54713af50b4207ffb7 (diff)
More DiffNeXt. (#847)

* More DiffNeXt.
* Down blocks.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs | 146
1 file changed, 144 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 8e5099f6..5e49437c 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -1,5 +1,5 @@
#![allow(unused)]
-use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@@ -52,24 +52,54 @@ impl ResBlockStageB {
}
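+// A single stage sub-block: residual block, timestep conditioning, and an optional attention block.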
#[derive(Debug)]
+struct SubBlock {
+ res_block: ResBlockStageB,
+ ts_block: TimestepBlock,
+ attn_block: Option<AttnBlock>,
+}
+
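+// Downsampling block: at every level but the first, a WLayerNorm and a stride-2 conv run before the sub-blocks.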
+#[derive(Debug)]
+struct DownBlock {
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::Conv2d>,
+ sub_blocks: Vec<SubBlock>,
+}
+
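+// Upsampling block: mirrors DownBlock, with the optional norm/conv placed after the sub-blocks.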
+#[derive(Debug)]
+struct UpBlock {
+ sub_blocks: Vec<SubBlock>,
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::Conv2d>,
+}
+
+#[derive(Debug)]
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<candle_nn::Conv2d>,
seq_norm: candle_nn::LayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
+ down_blocks: Vec<DownBlock>,
+ up_blocks: Vec<UpBlock>,
+ clf_ln: WLayerNorm,
+ clf_conv: candle_nn::Conv2d,
+ c_r: usize,
}
impl WDiffNeXt {
pub fn new(
c_in: usize,
c_out: usize,
- vb: VarBuilder,
+ c_r: usize,
c_cond: usize,
clip_embd: usize,
patch_size: usize,
+ vb: VarBuilder,
) -> Result<Self> {
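+        // Per-level hyperparameters: hidden width, sub-block count, attention heads (0 = no attention),
+        // and whether the EfficientNet conditioning gets injected.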
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+ const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+ const NHEAD: [usize; 4] = [0, 10, 20, 20];
+ const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
let effnet_mappers = vec![];
@@ -85,12 +115,124 @@ impl WDiffNeXt {
Default::default(),
vb.pp("embedding.2"),
)?;
+
+ let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+ let vb = vb.pp("down_blocks").pp(i);
+ let (layer_norm, conv, start_layer_i) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
+ (Some(layer_norm), Some(conv), 2)
+ } else {
+ (None, None, 0)
+ };
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = start_layer_i;
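+            // layer_i indexes the next child module within this block's sequential weights (vb.pp(layer_i)).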
+ for j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+            let attn_block = if i == 0 {
+ None
+ } else {
+ let attn_block =
+ AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let down_block = DownBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ down_blocks.push(down_block)
+ }
+
+ // TODO: populate.
+ let up_blocks = Vec::with_capacity(C_HIDDEN.len());
+
+ let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+ let clf_conv = candle_nn::conv2d(
+ C_HIDDEN[0],
+ 2 * c_out * patch_size * patch_size,
+ 1,
+ Default::default(),
+ vb.pp("clf.1"),
+ )?;
Ok(Self {
clip_mapper,
effnet_mappers,
seq_norm,
embedding_conv,
embedding_ln,
+ down_blocks,
+ up_blocks,
+ clf_ln,
+ clf_conv,
+ c_r,
})
}
+
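+    // Sinusoidal embedding of the noise level r, analogous to transformer positional encodings.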
+ fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
+ fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+ clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ r: &Tensor,
+ effnet: &Tensor,
+ clip: Option<&Tensor>,
+ ) -> Result<Tensor> {
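+        // EPS keeps b = sigmoid(...) inside (EPS, 1 - EPS) so the division at the end cannot blow up.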
+ const EPS: f64 = 1e-3;
+
+ let r_embed = self.gen_r_embedding(r)?;
+ let clip = match clip {
+ None => None,
+ Some(clip) => Some(self.gen_c_embeddings(clip)?),
+ };
+ let x_in = xs;
+
+ // TODO: pixel unshuffle.
+ let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
+ // TODO: down blocks
+ let level_outputs = xs.clone();
+ // TODO: up blocks
+ let xs = level_outputs;
+        // TODO: pixel shuffle.
+        let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(2, 1)?;
+ let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+ (x_in - &ab[0])? / b
+ }
}
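
A rough usage sketch (not from this commit): the reworked constructor takes c_r
ahead of c_cond and moves the VarBuilder to the last position. All hyperparameter
values below are illustrative assumptions rather than values from this diff, and
patch_size = 1 sidesteps the pixel (un)shuffle steps that are still TODOs here,
so the stubbed forward pass keeps shapes consistent end to end.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::diffnext::WDiffNeXt;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Randomly initialized weights, enough to check that everything wires up;
    // real use would load a checkpoint into the VarBuilder instead.
    let varmap = candle_nn::VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Hypothetical sizes: c_in = c_out = 4, c_r = 64, c_cond = 1024,
    // clip_embd = 1024, patch_size = 1.
    let model = WDiffNeXt::new(4, 4, 64, 1024, 1024, 1, vb)?;
    let xs = Tensor::zeros((1, 4, 64, 64), DType::F32, &device)?;
    let r = Tensor::zeros(1, DType::F32, &device)?;
    // The effnet conditioning is not consumed yet by this commit's forward stub.
    let effnet = Tensor::zeros((1, 1024, 8, 8), DType::F32, &device)?;
    let out = model.forward(&xs, &r, &effnet, None)?;
    println!("{:?}", out.shape()); // should print [1, 4, 64, 64]
    Ok(())
}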