summaryrefslogtreecommitdiff
path: root/candle-nn/src/group_norm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/group_norm.rs')
-rw-r--r--candle-nn/src/group_norm.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index ac77db4b..e85c4379 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -34,8 +34,10 @@ impl GroupNorm {
num_groups,
})
}
+}
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for GroupNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_shape = x.dims();
if x_shape.len() <= 2 {
candle::bail!("input rank for GroupNorm should be at least 3");