diff options
Diffstat (limited to 'candle-nn/src/group_norm.rs')
-rw-r--r-- | candle-nn/src/group_norm.rs | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index ac77db4b..e85c4379 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -34,8 +34,10 @@ impl GroupNorm { num_groups, }) } +} - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { +impl crate::Module for GroupNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { let x_shape = x.dims(); if x_shape.len() <= 2 { candle::bail!("input rank for GroupNorm should be at least 3"); |