summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-15 15:06:21 +0200
committerGitHub <noreply@github.com>2023-09-15 14:06:21 +0100
commit30be5b6660ca86f8ddd2cca88890cf4e40e45e12 (patch)
tree51d2f13e6a3b70d9c85b1db0c79f59ccbcc12ebc /candle-nn/src
parent107d3d953070f7817b3aaac9ed8ca0fed7030d01 (diff)
downloadcandle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.gz
candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.bz2
candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.zip
Replication pad (#861)
* Add the embed mapper convolutions. * Add the replication pad layer. * Use the replication-pad op. * Tweak a todo.
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/ops.rs15
1 files changed, 15 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 16b2e924..1256a076 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
.permute((0, 1, 3, 5, 2, 4))?
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}
+
+// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
+pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
+ match pad {
+ 0 => Ok(xs.clone()),
+ 1 => {
+ let (_b_size, _c, h, w) = xs.dims4()?;
+ let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
+ let xs = Tensor::cat(&[&first, xs, &last], 3)?;
+ let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
+ Tensor::cat(&[&first, &xs, &last], 2)
+ }
+ n => candle::bail!("replication-pad with a size of {n} is not supported"),
+ }
+}