summaryrefslogtreecommitdiff
path: root/candle-nn/src/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r--candle-nn/src/ops.rs15
1 files changed, 15 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 16b2e924..1256a076 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
.permute((0, 1, 3, 5, 2, 4))?
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}
+
+// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
+pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
+ match pad {
+ 0 => Ok(xs.clone()),
+ 1 => {
+ let (_b_size, _c, h, w) = xs.dims4()?;
+ let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
+ let xs = Tensor::cat(&[&first, xs, &last], 3)?;
+ let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
+ Tensor::cat(&[&first, &xs, &last], 2)
+ }
+ n => candle::bail!("replication-pad with a size of {n} is not supported"),
+ }
+}