summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 4cf9d0ad..30799459 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -25,3 +25,32 @@ impl ParamsConv1D {
}
}
}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParamsConv2D {
+ pub(crate) b_size: usize,
+ pub(crate) i_h: usize,
+ pub(crate) i_w: usize,
+ pub(crate) k_h: usize,
+ pub(crate) k_w: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv2D {
+ pub(crate) fn out_h(&self) -> usize {
+ let dilation = 1;
+ (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_w(&self) -> usize {
+ let dilation = 1;
+ (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+ }
+}