summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend.rs57
1 files changed, 52 insertions, 5 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index c997d767..0c4e4597 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1039,12 +1039,59 @@ impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
&self,
- _inp: &[T],
- _inp_l: &Layout,
- _k: &[T],
- _k_l: &Layout,
+ inp: &[T],
+ inp_l: &Layout,
+ k: &[T],
+ k_l: &Layout,
) -> Result<Vec<T>> {
- todo!()
+ let p = self.0;
+ let inp = &inp[inp_l.start_offset()..];
+ let inp_stride = inp_l.stride();
+ let k = &k[k_l.start_offset()..];
+ let k_stride = k_l.stride();
+ let (out_h, out_w) = (p.out_h(), p.out_w());
+
+ let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+ for b_idx in 0..p.b_size {
+ let inp_idx = b_idx * inp_stride[0];
+ let dst_idx = b_idx * p.c_out * out_h * out_w;
+ for dst_c_idx in 0..p.c_out {
+ let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
+ for dst_h in 0..out_h {
+ let dst_idx = dst_idx + dst_h * out_w;
+ for dst_w in 0..out_h {
+ let dst_idx = dst_idx + dst_w;
+ let mut d = T::zero();
+ for offset_h in 0..p.k_h {
+ let src_h_plus = p.stride * dst_h + offset_h;
+ if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
+ let src_h = src_h_plus - p.k_h / 2;
+ for offset_w in 0..p.k_w {
+ let src_w_plus = p.stride * dst_w + offset_w;
+ // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
+ if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
+ let src_w = src_w_plus - p.k_w / 2;
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = inp_idx
+ + src_c_idx * inp_stride[1]
+ + src_h * inp_stride[2]
+ + src_w * inp_stride[3];
+ let k_idx = dst_c_idx * k_stride[0]
+ + src_c_idx * k_stride[1]
+ + offset_h * k_stride[2]
+ + offset_w * k_stride[3];
+ d += inp[inp_idx] * k[k_idx]
+ }
+ }
+ }
+ }
+ }
+ dst[dst_idx] = d
+ }
+ }
+ }
+ }
+ Ok(dst)
}
}