Diffstat (limited to 'candle-nn/src/group_norm.rs')
-rw-r--r-- candle-nn/src/group_norm.rs | 83
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
new file mode 100644
index 00000000..ac77db4b
--- /dev/null
+++ b/candle-nn/src/group_norm.rs
@@ -0,0 +1,83 @@
+//! Group Normalization.
+//!
+//! This layer applies Group Normalization over a mini-batch of inputs.
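+//!
+//! The channels are split into `num_groups` groups; each group is normalized
+//! to zero mean and unit variance before a learned per-channel weight and
+//! bias are applied:
+//!
+//! `y = (x - mean(group)) / sqrt(var(group) + eps) * weight + bias`
+//!
+//! A minimal usage sketch (the constant weight/bias tensors below are
+//! placeholders for illustration; in practice they come from a `VarBuilder`
+//! via `group_norm`):
+//!
+//! ```ignore
+//! use candle::{DType, Device, Tensor};
+//! use candle_nn::GroupNorm;
+//!
+//! let device = Device::Cpu;
+//! let weight = Tensor::ones(16, DType::F32, &device)?;
+//! let bias = Tensor::zeros(16, DType::F32, &device)?;
+//! // 16 channels split into 4 groups of 4 channels each.
+//! let gn = GroupNorm::new(weight, bias, 16, 4, 1e-5)?;
+//! // Input of shape (batch, channels, positions); rank must be at least 3.
+//! let xs = Tensor::randn(0f32, 1f32, (2, 16, 10), &device)?;
+//! let ys = gn.forward(&xs)?;
+//! assert_eq!(ys.dims(), &[2, 16, 10]);
+//! ```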
+use candle::{DType, Result, Tensor};
+
+// This group norm version handles both a weight and a bias, and so (unlike an
+// RMS-style norm) it also subtracts the mean.
+#[derive(Debug)]
+pub struct GroupNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    num_channels: usize,
+    num_groups: usize,
+}
+
+impl GroupNorm {
+    pub fn new(
+        weight: Tensor,
+        bias: Tensor,
+        num_channels: usize,
+        num_groups: usize,
+        eps: f64,
+    ) -> Result<Self> {
+        if num_channels % num_groups != 0 {
+            candle::bail!(
+                "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+            )
+        }
+        Ok(Self {
+            weight,
+            bias,
+            eps,
+            num_channels,
+            num_groups,
+        })
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_shape = x.dims();
+        if x_shape.len() <= 2 {
+            candle::bail!("input rank for GroupNorm should be at least 3");
+        }
+        let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
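+        // Each group normalizes (num_channels / num_groups) channels together
+        // with all of their spatial positions.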
+        let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+        if n_channels != self.num_channels {
+            candle::bail!(
+                "unexpected num-channels in GroupNorm ({n_channels} <> {})",
+                self.num_channels
+            )
+        }
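+        // Compute the reductions in f32 when the input is half precision, to
+        // avoid precision loss when accumulating.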
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
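+        // Collapse each group into a single dimension: (batch, num_groups, hidden_size).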
+        let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
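+        // Biased variance of each group, computed from the already-centered values.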
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
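+        // Reshape weight and bias to (1, C, 1, ...) so they broadcast over the
+        // channel dimension.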
+        let mut w_dims = vec![1; x_shape.len()];
+        w_dims[1] = n_channels;
+        let weight = self.weight.reshape(w_dims.clone())?;
+        let bias = self.bias.reshape(w_dims)?;
+        x_normed
+            .to_dtype(x_dtype)?
+            .reshape(x_shape)?
+            .broadcast_mul(&weight)?
+            .broadcast_add(&bias)
+    }
+}
+
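+/// Creates a `GroupNorm` layer, fetching the per-channel `weight` and `bias`
+/// through the `VarBuilder` (defaulting to ones and zeros respectively).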
+pub fn group_norm(
+    num_groups: usize,
+    num_channels: usize,
+    eps: f64,
+    vb: crate::VarBuilder,
+) -> Result<GroupNorm> {
+    let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
+    let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
+    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
+}