diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-11 15:53:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-11 14:53:05 +0100 |
commit | 662db45fc3728c4d7256edee76b3dd99877cb53f (patch) | |
tree | c93d3c8ca683cf3819beb9c52d3f2970c510885b | |
parent | 906c0f3eb50d9c128e59718f27125babcea9986e (diff) | |
download | candle-662db45fc3728c4d7256edee76b3dd99877cb53f.tar.gz candle-662db45fc3728c4d7256edee76b3dd99877cb53f.tar.bz2 candle-662db45fc3728c4d7256edee76b3dd99877cb53f.zip |
Use zero padding in conv1d and conv2d (same as pytorch). (#408)
-rw-r--r-- | candle-core/src/cpu_backend.rs | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 86f14e32..dcf4ed94 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1057,9 +1057,11 @@ impl<'a> Map2 for Conv1D<'a> { let dst_idx = dst_idx + b_idx * p.c_out * l_out; for dst_l in 0..l_out { let dst_idx = dst_idx + dst_l; - let src_l = (p.stride * dst_l + offset) - .saturating_sub(p.padding) - .min(p.l_in - 1); + let src_l = p.stride * dst_l + offset; + if src_l < p.padding || src_l >= p.padding + p.l_in { + continue; + } + let src_l = src_l - p.padding; let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..]; assert!(inp_cont.len() >= p.c_in); assert!(k_cont.len() >= p.c_in); @@ -1132,14 +1134,18 @@ impl<'a> Map2 for Conv2D<'a> { let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; for dst_h in 0..out_h { let dst_idx = dst_idx + dst_h * out_w; - let src_h = (p.stride * dst_h + offset_h) - .saturating_sub(p.padding) - .min(p.i_h - 1); + let src_h = p.stride * dst_h + offset_h; + if src_h < p.padding || src_h >= p.i_h + p.padding { + continue; + } + let src_h = src_h - p.padding; for dst_w in 0..out_w { let dst_idx = dst_idx + dst_w; - let src_w = (p.stride * dst_w + offset_w) - .saturating_sub(p.padding) - .min(p.i_w - 1); + let src_w = p.stride * dst_w + offset_w; + if src_w < p.padding || src_w >= p.i_w + p.padding { + continue; + } + let src_w = src_w - p.padding; let inp_cont = &inp_cont [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..]; assert!(inp_cont.len() >= p.c_in); |