summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-05 13:06:33 +0100
committerGitHub <noreply@github.com>2023-07-05 13:06:33 +0100
commit93896f6596e44285f6250f4966ada8c08fa85f09 (patch)
treefee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-core/src/conv.rs
parentd8f75ceeaa4702b641a9f71ec348fc54a32f4cd7 (diff)
parentbce28ab7938b27931fd51e59c8bcad37038e0337 (diff)
downloadcandle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.gz
candle-93896f6596e44285f6250f4966ada8c08fa85f09.tar.bz2
candle-93896f6596e44285f6250f4966ada8c08fa85f09.zip
Merge branch 'main' into upgrade_bert
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs27
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
new file mode 100644
index 00000000..041bb6fb
--- /dev/null
+++ b/candle-core/src/conv.rs
@@ -0,0 +1,27 @@
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct ParamsConv1D {
+ pub(crate) b_size: Option<usize>,
+ // Maybe we should have a version without l_in as this bit depends on the input and not only on
+ // the weights.
+ pub(crate) l_in: usize,
+ pub(crate) c_out: usize,
+ pub(crate) c_in: usize,
+ pub(crate) k_size: usize,
+ pub(crate) padding: usize,
+ pub(crate) stride: usize,
+}
+
+impl ParamsConv1D {
+ pub(crate) fn l_out(&self) -> usize {
+ let dilation = 1;
+ (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ }
+
+ pub(crate) fn out_dims(&self) -> Vec<usize> {
+ let l_out = self.l_out();
+ match self.b_size {
+ None => vec![self.c_out, l_out],
+ Some(n) => vec![n, self.c_out, l_out],
+ }
+ }
+}