summaryrefslogtreecommitdiff
path: root/candle-nn/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r--candle-nn/src/conv.rs8
1 files changed, 6 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 67a80417..5057d2ef 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -35,8 +35,10 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv1d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
@@ -84,8 +86,10 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv2d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),