diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-04 11:15:45 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-04 11:15:45 +0100 |
commit | a424d95473ea9268ffb1dde4d73ce0cff9904845 (patch) | |
tree | 88064cc2c8cb2b12a9c5ab3b3c3ac9a70df798b6 /candle-core/src/conv.rs | |
parent | 3aac1047fec43a4d756ae4e60a8ae82f7c3e636e (diff) | |
download | candle-a424d95473ea9268ffb1dde4d73ce0cff9904845.tar.gz candle-a424d95473ea9268ffb1dde4d73ce0cff9904845.tar.bz2 candle-a424d95473ea9268ffb1dde4d73ce0cff9904845.zip |
Add more of the conv1d op.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs new file mode 100644 index 00000000..90bb5229 --- /dev/null +++ b/candle-core/src/conv.rs @@ -0,0 +1,24 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ParamsConv1D { + pub(crate) b_size: Option<usize>, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) k_size: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv1D { + pub(crate) fn l_out(&self, l_in: usize) -> usize { + let dilation = 1; + (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> { + let l_out = self.l_out(l_in); + match self.b_size { + None => vec![self.c_out, l_out], + Some(n) => vec![n, self.c_out, l_out], + } + } +} |