summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/group_norm.rs12
1 files changed, 8 insertions, 4 deletions
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index e277ae85..ac77db4b 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -59,17 +59,21 @@ impl GroupNorm {
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let mut w_dims = vec![1; x_shape.len()];
+ w_dims[1] = n_channels;
+ let weight = self.weight.reshape(w_dims.clone())?;
+ let bias = self.bias.reshape(w_dims)?;
x_normed
.to_dtype(x_dtype)?
- .broadcast_mul(&self.weight)?
- .broadcast_add(&self.bias)?
- .reshape(x_shape)
+ .reshape(x_shape)?
+ .broadcast_mul(&weight)?
+ .broadcast_add(&bias)
}
}
pub fn group_norm(
- num_channels: usize,
num_groups: usize,
+ num_channels: usize,
eps: f64,
vb: crate::VarBuilder,
) -> Result<GroupNorm> {