diff options
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 4cf9d0ad..30799459 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -25,3 +25,32 @@ impl ParamsConv1D { } } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConv2D { + pub(crate) b_size: usize, + pub(crate) i_h: usize, + pub(crate) i_w: usize, + pub(crate) k_h: usize, + pub(crate) k_w: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv2D { + pub(crate) fn out_h(&self) -> usize { + let dilation = 1; + (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_w(&self) -> usize { + let dilation = 1; + (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec<usize> { + vec![self.b_size, self.c_out, self.out_h(), self.out_w()] + } +} |