summaryrefslogtreecommitdiff
path: root/candle-transformers/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--candle-transformers/src/models/wuerstchen/paella_vq.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 6301b7a1..6589a07d 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -49,7 +49,7 @@ impl Module for MixingResidualBlock {
.apply(&self.norm1)?
.permute((0, 3, 1, 2))?
.affine(1. + mods[0] as f64, mods[1] as f64)?;
- // TODO: Add the ReplicationPad2d
+ let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
let x_temp = xs
.permute((0, 2, 3, 1))?
@@ -88,10 +88,10 @@ impl PaellaVQ {
}
xs.apply(&self.down_blocks_conv)?
.apply(&self.down_blocks_bn)
- // TODO: quantizer
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: quantizer if we want to support `force_not_quantize=False`.
let mut xs = xs.apply(&self.up_blocks_conv)?;
for up_block in self.up_blocks.iter() {
xs = xs.apply(&up_block.0)?;