diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-15 15:06:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 14:06:21 +0100 |
commit | 30be5b6660ca86f8ddd2cca88890cf4e40e45e12 (patch) | |
tree | 51d2f13e6a3b70d9c85b1db0c79f59ccbcc12ebc /candle-transformers/src | |
parent | 107d3d953070f7817b3aaac9ed8ca0fed7030d01 (diff) | |
download | candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.gz candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.bz2 candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.zip |
Replication pad (#861)
* Add the embed mapper convolutions.
* Add the replication pad layer.
* Use the replication-pad op.
* Tweak a todo.
Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/paella_vq.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs index 6301b7a1..6589a07d 100644 --- a/candle-transformers/src/models/wuerstchen/paella_vq.rs +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -49,7 +49,7 @@ impl Module for MixingResidualBlock { .apply(&self.norm1)? .permute((0, 3, 1, 2))? .affine(1. + mods[0] as f64, mods[1] as f64)?; - // TODO: Add the ReplicationPad2d + let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?; let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; let x_temp = xs .permute((0, 2, 3, 1))? @@ -88,10 +88,10 @@ impl PaellaVQ { } xs.apply(&self.down_blocks_conv)? .apply(&self.down_blocks_bn) - // TODO: quantizer } pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + // TODO: quantizer if we want to support `force_not_quantize=False`. let mut xs = xs.apply(&self.up_blocks_conv)?; for up_block in self.up_blocks.iter() { xs = xs.apply(&up_block.0)?; |