Diffstat (limited to 'candle-nn/src')
 -rw-r--r--  candle-nn/src/group_norm.rs | 47
 -rw-r--r--  candle-nn/src/ops.rs        |  6
 2 files changed, 45 insertions(+), 8 deletions(-)
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index 4b9bed73..e277ae85 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -1,10 +1,9 @@
 //! Group Normalization.
 //!
 //! This layer applies Group Normalization over a mini-batch of inputs.
-use candle::{Result, Tensor};
+use candle::{DType, Result, Tensor};
 
 // This group norm version handles both weight and bias so removes the mean.
-#[allow(dead_code)]
 #[derive(Debug)]
 pub struct GroupNorm {
     weight: Tensor,
@@ -21,18 +20,50 @@ impl GroupNorm {
         num_channels: usize,
         num_groups: usize,
         eps: f64,
-    ) -> Self {
-        Self {
+    ) -> Result<Self> {
+        if num_channels % num_groups != 0 {
+            candle::bail!(
+                "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+            )
+        }
+        Ok(Self {
             weight,
             bias,
             eps,
             num_channels,
             num_groups,
-        }
+        })
     }
 
-    pub fn forward(&self, _: &Tensor) -> Result<Tensor> {
-        todo!()
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_shape = x.dims();
+        if x_shape.len() <= 2 {
+            candle::bail!("input rank for GroupNorm should be at least 3");
+        }
+        let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
+        let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+        if n_channels != self.num_channels {
+            candle::bail!(
+                "unexpected num-channels in GroupNorm ({n_channels} <> {})",
+                self.num_channels
+            )
+        }
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?
+            .reshape(x_shape)
     }
 }
 
@@ -44,5 +75,5 @@ pub fn group_norm(
 ) -> Result<GroupNorm> {
     let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
     let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
-    Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps))
+    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
 }
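
A quick usage sketch, not part of the commit, showing the effect of the constructor change: GroupNorm::new now returns Result and rejects a group count that does not divide the channel count. It assumes GroupNorm is re-exported at the candle-nn crate root and uses the workspace's `candle` crate name as in the diff.

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::GroupNorm;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Per-channel affine parameters, like the ones the group_norm builder creates.
        let weight = Tensor::ones(4, DType::F32, &dev)?;
        let bias = Tensor::zeros(4, DType::F32, &dev)?;
        // 4 channels split into 2 groups: accepted.
        let _gn = GroupNorm::new(weight.clone(), bias.clone(), 4, 2, 1e-5)?;
        // 3 groups do not divide 4 channels: new() now surfaces this as an error.
        assert!(GroupNorm::new(weight, bias, 4, 3, 1e-5).is_err());
        Ok(())
    }

Propagating the check through Result (and through the group_norm builder above) reports the mismatch at construction time instead of leaving it to fail inside forward.
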
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 29cc6973..397674f3 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -34,5 +34,11 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
 }
 
 pub fn silu(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
     xs / (xs.neg()?.exp()? + 1.0)?
 }
+
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Should we have a specialized op for this?
+    (xs.neg()?.exp()? + 1.0)?.recip()
+}
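
For reference, a small sketch, again not part of the commit, of how the two exp-based helpers relate: silu(x) = x * sigmoid(x), so the results should match up to floating-point rounding. The candle_nn::ops module path is assumed.

    use candle::{Device, Result, Tensor};
    use candle_nn::ops;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &dev)?;
        // silu(x) = x / (1 + e^-x) and sigmoid(x) = 1 / (1 + e^-x),
        // so multiplying the input by its sigmoid reproduces silu.
        let silu = ops::silu(&xs)?;
        let via_sigmoid = xs.mul(&ops::sigmoid(&xs)?)?;
        println!("{silu}\n{via_sigmoid}");
        Ok(())
    }

Both helpers are plain compositions of existing tensor ops, which is what the TODO comments about a specialized op refer to.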