summaryrefslogtreecommitdiff
path: root/candle-transformers/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-15 15:06:21 +0200
committerGitHub <noreply@github.com>2023-09-15 14:06:21 +0100
commit30be5b6660ca86f8ddd2cca88890cf4e40e45e12 (patch)
tree51d2f13e6a3b70d9c85b1db0c79f59ccbcc12ebc /candle-transformers/src
parent107d3d953070f7817b3aaac9ed8ca0fed7030d01 (diff)
downloadcandle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.gz
candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.bz2
candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.zip
Replication pad (#861)
* Add the embed mapper convolutions. * Add the replication pad layer. * Use the replication-pad op. * Tweak a todo.
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--candle-transformers/src/models/wuerstchen/paella_vq.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 6301b7a1..6589a07d 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -49,7 +49,7 @@ impl Module for MixingResidualBlock {
.apply(&self.norm1)?
.permute((0, 3, 1, 2))?
.affine(1. + mods[0] as f64, mods[1] as f64)?;
- // TODO: Add the ReplicationPad2d
+ let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
let x_temp = xs
.permute((0, 2, 3, 1))?
@@ -88,10 +88,10 @@ impl PaellaVQ {
}
xs.apply(&self.down_blocks_conv)?
.apply(&self.down_blocks_bn)
- // TODO: quantizer
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: quantizer if we want to support `force_not_quantize=False`.
let mut xs = xs.apply(&self.up_blocks_conv)?;
for up_block in self.up_blocks.iter() {
xs = xs.apply(&up_block.0)?;