diff options
Diffstat (limited to 'candle-examples/examples/musicgen/encodec_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/encodec_model.rs | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 095c90a9..60149e45 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder}; #[derive(Debug, Clone, PartialEq)] enum NormType { WeightNorm, + TimeGroupNorm, None, } @@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d { struct EncodecConv1d { causal: bool, conv: Conv1d, + norm: Option<candle_nn::GroupNorm>, } impl EncodecConv1d { @@ -292,7 +294,7 @@ impl EncodecConv1d { }, vb.pp("conv"), )?, - NormType::None => conv1d( + NormType::None | NormType::TimeGroupNorm => conv1d( in_c, out_c, kernel_size, @@ -305,9 +307,17 @@ impl EncodecConv1d { vb.pp("conv"), )?, }; + let norm = match cfg.norm_type { + NormType::None | NormType::WeightNorm => None, + NormType::TimeGroupNorm => { + let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(gn) + } + }; Ok(Self { causal: cfg.use_causal_conv, conv, + norm, }) } } @@ -316,8 +326,10 @@ impl Module for EncodecConv1d { fn forward(&self, xs: &Tensor) -> Result<Tensor> { // TODO: padding, depending on causal. let xs = self.conv.forward(xs)?; - // If we add support for NormType "time_group_norm", we should add some normalization here. - Ok(xs) + match &self.norm { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } } } |