Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/group_norm.rs | 47 |
-rw-r--r-- | candle-nn/src/ops.rs        |  6 |
2 files changed, 45 insertions, 8 deletions
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index 4b9bed73..e277ae85 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -1,10 +1,9 @@
 //! Group Normalization.
 //!
 //! This layer applies Group Normalization over a mini-batch of inputs.
-use candle::{Result, Tensor};
+use candle::{DType, Result, Tensor};
 
 // This group norm version handles both weight and bias so removes the mean.
-#[allow(dead_code)]
 #[derive(Debug)]
 pub struct GroupNorm {
     weight: Tensor,
@@ -21,18 +20,50 @@ impl GroupNorm {
         num_channels: usize,
         num_groups: usize,
         eps: f64,
-    ) -> Self {
-        Self {
+    ) -> Result<Self> {
+        if num_channels % num_groups != 0 {
+            candle::bail!(
+                "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+            )
+        }
+        Ok(Self {
             weight,
             bias,
             eps,
             num_channels,
             num_groups,
-        }
+        })
     }
 
-    pub fn forward(&self, _: &Tensor) -> Result<Tensor> {
-        todo!()
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_shape = x.dims();
+        if x_shape.len() <= 2 {
+            candle::bail!("input rank for GroupNorm should be at least 3");
+        }
+        let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
+        let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+        if n_channels != self.num_channels {
+            candle::bail!(
+                "unexpected num-channels in GroupNorm ({n_channels} <> {}",
+                self.num_channels
+            )
+        }
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?
+            .reshape(x_shape)
     }
 }
 
@@ -44,5 +75,5 @@ pub fn group_norm(
 ) -> Result<GroupNorm> {
     let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
     let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
-    Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps))
+    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
 }
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 29cc6973..397674f3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -34,5 +34,11 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
 }
 
 pub fn silu(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
     xs / (xs.neg()?.exp()? + 1.0)?
 }
+
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
+    (xs.neg()?.exp()? + 1.0)?.recip()
+}
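
For readers trying the change out, a minimal usage sketch follows. It is not part of the commit: it assumes candle and candle-nn as dependencies and the usual crate-root re-exports (candle_nn::GroupNorm, candle_nn::ops), and it only exercises the new constructor validation and the two activation helpers, using standard Tensor constructors.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{ops, GroupNorm};

fn main() -> Result<()> {
    // Hypothetical example, not part of the diff above.
    let dev = Device::Cpu;

    // Per-channel affine parameters, matching the defaults used by `group_norm`
    // (weight initialized to 1, bias to 0).
    let weight = Tensor::ones(64, DType::F32, &dev)?;
    let bias = Tensor::zeros(64, DType::F32, &dev)?;

    // 64 channels split into 8 groups passes the new divisibility check...
    let _gn = GroupNorm::new(weight.clone(), bias.clone(), 64, 8, 1e-5)?;
    // ...while 7 groups does not divide 64 channels, so `new` now bails.
    assert!(GroupNorm::new(weight, bias, 64, 7, 1e-5).is_err());

    // The two helpers added to ops.rs: sigmoid(x) = 1 / (1 + exp(-x)) and
    // silu(x) = x * sigmoid(x), both built from existing tensor ops.
    let xs = Tensor::new(&[-2.0f32, 0.0, 2.0], &dev)?;
    println!("sigmoid: {}", ops::sigmoid(&xs)?);
    println!("silu:    {}", ops::silu(&xs)?);
    Ok(())
}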