diff options
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r-- | candle-nn/src/ops.rs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 16b2e924..1256a076 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> { .permute((0, 1, 3, 5, 2, 4))? .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) } + +// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html +pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> { + match pad { + 0 => Ok(xs.clone()), + 1 => { + let (_b_size, _c, h, w) = xs.dims4()?; + let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?); + let xs = Tensor::cat(&[&first, xs, &last], 3)?; + let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?); + Tensor::cat(&[&first, &xs, &last], 2) + } + n => candle::bail!("replication-pad with a size of {n} is not supported"), + } +} |