summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-04 11:15:45 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-04 11:15:45 +0100
commita424d95473ea9268ffb1dde4d73ce0cff9904845 (patch)
tree88064cc2c8cb2b12a9c5ab3b3c3ac9a70df798b6 /candle-core/src/conv.rs
parent3aac1047fec43a4d756ae4e60a8ae82f7c3e636e (diff)
downloadcandle-a424d95473ea9268ffb1dde4d73ce0cff9904845.tar.gz
candle-a424d95473ea9268ffb1dde4d73ce0cff9904845.tar.bz2
candle-a424d95473ea9268ffb1dde4d73ce0cff9904845.zip
Add more of the conv1d op.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs24
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
new file mode 100644
index 00000000..90bb5229
--- /dev/null
+++ b/candle-core/src/conv.rs
@@ -0,0 +1,24 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+ pub(crate) b_size: Option<usize>,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) k_size: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+ pub(crate) fn l_out(&self, l_in: usize) -> usize {
+ let dilation = 1;
+ (l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self, l_in: usize) -> Vec<usize> {
+ let l_out = self.l_out(l_in);
+ match self.b_size {
+ None => vec![self.c_out, l_out],
+ Some(n) => vec![n, self.c_out, l_out],
+ }
+ }
+}