diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-08 23:10:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-08 22:10:59 +0100 |
commit | cf965ecaa854a2f562a11c3885a5ab837757a5a7 (patch) | |
tree | 74992c1ad3f1c6aefbb7589352d52a78791936c8 | |
parent | b9864e135729283fab450abaff16982cc96552be (diff) | |
download | candle-cf965ecaa854a2f562a11c3885a5ab837757a5a7.tar.gz candle-cf965ecaa854a2f562a11c3885a5ab837757a5a7.tar.bz2 candle-cf965ecaa854a2f562a11c3885a5ab837757a5a7.zip |
Simplify the conv1d and conv2d code. (#352)
-rw-r--r-- | candle-core/src/cpu_backend.rs | 61 |
1 files changed, 27 insertions, 34 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 210231d8..05c1f4e8 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -997,7 +997,6 @@ impl<'a> Map2 for Conv1D<'a> { (0, inp_stride) // This value never gets used anyway }; let k_stride = k_l.stride(); - let k_over_2 = p.k_size / 2; let l_out = p.l_out(); let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); let mut dst = vec![T::zero(); dst_elems]; @@ -1011,18 +1010,16 @@ impl<'a> Map2 for Conv1D<'a> { let dst_idx = dst_idx + dst_l; let mut d = T::zero(); for offset in 0..p.k_size { - let src_l_plus = p.stride * dst_l + offset + k_over_2 - p.padding; - // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] - if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in { - let src_l = src_l_plus - k_over_2; - for src_c_idx in 0..p.c_in { - let inp_idx = - inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset * k_stride[2]; - d += inp[inp_idx] * k[k_idx] - } + let src_l = (p.stride * dst_l + offset) + .saturating_sub(p.padding) + .min(p.l_in - 1); + for src_c_idx in 0..p.c_in { + let inp_idx = + inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset * k_stride[2]; + d += inp[inp_idx] * k[k_idx] } } dst[dst_idx] = d @@ -1064,27 +1061,23 @@ impl<'a> Map2 for Conv2D<'a> { let mut d = T::zero(); for offset_h in 0..p.k_h { // TODO: Handle the case where padding is larger than p.k_h / 2. - let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding; - if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h { - let src_h = src_h_plus - p.k_h / 2; - for offset_w in 0..p.k_w { - let src_w_plus = - p.stride * dst_w + offset_w + p.k_w / 2 - p.padding; - // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] - if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w { - let src_w = src_w_plus - p.k_w / 2; - for src_c_idx in 0..p.c_in { - let inp_idx = inp_idx - + src_c_idx * inp_stride[1] - + src_h * inp_stride[2] - + src_w * inp_stride[3]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset_h * k_stride[2] - + offset_w * k_stride[3]; - d += inp[inp_idx] * k[k_idx] - } - } + let src_h = (p.stride * dst_h + offset_h) + .saturating_sub(p.padding) + .min(p.i_h - 1); + for offset_w in 0..p.k_w { + let src_w = (p.stride * dst_w + offset_w) + .saturating_sub(p.padding) + .min(p.i_w - 1); + for src_c_idx in 0..p.c_in { + let inp_idx = inp_idx + + src_c_idx * inp_stride[1] + + src_h * inp_stride[2] + + src_w * inp_stride[3]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset_h * k_stride[2] + + offset_w * k_stride[3]; + d += inp[inp_idx] * k[k_idx] } } } |