diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-05 13:06:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-05 13:06:33 +0100 |
commit | 93896f6596e44285f6250f4966ada8c08fa85f09 (patch) | |
tree | fee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core/src/conv.rs | |
parent | d8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff) | |
parent | bce28ab7938b27931fd51e59c8bcad37038e0337 (diff) | |
download | candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.gz candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.bz2 candle-93896f6596e44285f6250f4966ada8c08fa85f09.zip |
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs new file mode 100644 index 00000000..041bb6fb --- /dev/null +++ b/candle-core/src/conv.rs @@ -0,0 +1,27 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ParamsConv1D { + pub(crate) b_size: Option<usize>, + // Maybe we should have a version without l_in as this bit depends on the input and not only on + // the weights. + pub(crate) l_in: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) k_size: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv1D { + pub(crate) fn l_out(&self) -> usize { + let dilation = 1; + (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec<usize> { + let l_out = self.l_out(); + match self.b_size { + None => vec![self.c_out, l_out], + Some(n) => vec![n, self.c_out, l_out], + } + } +} |