diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-15 15:06:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 14:06:21 +0100 |
commit | 30be5b6660ca86f8ddd2cca88890cf4e40e45e12 (patch) | |
tree | 51d2f13e6a3b70d9c85b1db0c79f59ccbcc12ebc /candle-nn/src | |
parent | 107d3d953070f7817b3aaac9ed8ca0fed7030d01 (diff) | |
download | candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.gz candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.tar.bz2 candle-30be5b6660ca86f8ddd2cca88890cf4e40e45e12.zip |
Replication pad (#861)
* Add the embed mapper convolutions.
* Add the replication pad layer.
* Use the replication-pad op.
* Tweak a todo.
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/ops.rs | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 16b2e924..1256a076 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> { .permute((0, 1, 3, 5, 2, 4))? .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) } + +// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html +pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> { + match pad { + 0 => Ok(xs.clone()), + 1 => { + let (_b_size, _c, h, w) = xs.dims4()?; + let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?); + let xs = Tensor::cat(&[&first, xs, &last], 3)?; + let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?); + Tensor::cat(&[&first, &xs, &last], 2) + } + n => candle::bail!("replication-pad with a size of {n} is not supported"), + } +} |