diff options
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r-- | candle-nn/src/conv.rs | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 67a80417..5057d2ef 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -35,8 +35,10 @@ impl Conv1d { pub fn config(&self) -> &Conv1dConfig { &self.config } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for Conv1d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; match &self.bias { None => Ok(x), @@ -84,8 +86,10 @@ impl Conv2d { pub fn config(&self) -> &Conv2dConfig { &self.config } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for Conv2d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; match &self.bias { None => Ok(x), |