summary | refs | log | tree | commit | diff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- candle-core/src/cpu_backend.rs | 61
1 files changed, 27 insertions, 34 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 210231d8..05c1f4e8 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -997,7 +997,6 @@ impl<'a> Map2 for Conv1D<'a> {
(0, inp_stride) // This value never gets used anyway
};
let k_stride = k_l.stride();
- let k_over_2 = p.k_size / 2;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems];
@@ -1011,18 +1010,16 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + dst_l;
let mut d = T::zero();
for offset in 0..p.k_size {
- let src_l_plus = p.stride * dst_l + offset + k_over_2 - p.padding;
- // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
- if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
- let src_l = src_l_plus - k_over_2;
- for src_c_idx in 0..p.c_in {
- let inp_idx =
- inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset * k_stride[2];
- d += inp[inp_idx] * k[k_idx]
- }
+ let src_l = (p.stride * dst_l + offset)
+ .saturating_sub(p.padding)
+ .min(p.l_in - 1);
+ for src_c_idx in 0..p.c_in {
+ let inp_idx =
+ inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset * k_stride[2];
+ d += inp[inp_idx] * k[k_idx]
}
}
dst[dst_idx] = d
@@ -1064,27 +1061,23 @@ impl<'a> Map2 for Conv2D<'a> {
let mut d = T::zero();
for offset_h in 0..p.k_h {
// TODO: Handle the case where padding is larger than p.k_h / 2.
- let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding;
- if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
- let src_h = src_h_plus - p.k_h / 2;
- for offset_w in 0..p.k_w {
- let src_w_plus =
- p.stride * dst_w + offset_w + p.k_w / 2 - p.padding;
- // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
- if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
- let src_w = src_w_plus - p.k_w / 2;
- for src_c_idx in 0..p.c_in {
- let inp_idx = inp_idx
- + src_c_idx * inp_stride[1]
- + src_h * inp_stride[2]
- + src_w * inp_stride[3];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset_h * k_stride[2]
- + offset_w * k_stride[3];
- d += inp[inp_idx] * k[k_idx]
- }
- }
+ let src_h = (p.stride * dst_h + offset_h)
+ .saturating_sub(p.padding)
+ .min(p.i_h - 1);
+ for offset_w in 0..p.k_w {
+ let src_w = (p.stride * dst_w + offset_w)
+ .saturating_sub(p.padding)
+ .min(p.i_w - 1);
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = inp_idx
+ + src_c_idx * inp_stride[1]
+ + src_h * inp_stride[2]
+ + src_w * inp_stride[3];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset_h * k_stride[2]
+ + offset_w * k_stride[3];
+ d += inp[inp_idx] * k[k_idx]
}
}
}