diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-07 22:39:59 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-07 22:39:59 +0100 |
commit | 7920b45c8ac737b67e23f04297f6bd7e4860f373 (patch) | |
tree | 1bfc4e3cf8bc84f7029143a61f8dd91b3b27fbdf /candle-examples/examples/musicgen | |
parent | d4a45c936a3acb79f89d5a18bcfa3ef34f11ae45 (diff) | |
download | candle-7920b45c8ac737b67e23f04297f6bd7e4860f373.tar.gz candle-7920b45c8ac737b67e23f04297f6bd7e4860f373.tar.bz2 candle-7920b45c8ac737b67e23f04297f6bd7e4860f373.zip |
Support for timegroupnorm in encodec. (#1291)
Diffstat (limited to 'candle-examples/examples/musicgen')
-rw-r--r-- | candle-examples/examples/musicgen/encodec_model.rs | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 095c90a9..60149e45 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder}; #[derive(Debug, Clone, PartialEq)] enum NormType { WeightNorm, + TimeGroupNorm, None, } @@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d { struct EncodecConv1d { causal: bool, conv: Conv1d, + norm: Option<candle_nn::GroupNorm>, } impl EncodecConv1d { @@ -292,7 +294,7 @@ impl EncodecConv1d { }, vb.pp("conv"), )?, - NormType::None => conv1d( + NormType::None | NormType::TimeGroupNorm => conv1d( in_c, out_c, kernel_size, @@ -305,9 +307,17 @@ impl EncodecConv1d { vb.pp("conv"), )?, }; + let norm = match cfg.norm_type { + NormType::None | NormType::WeightNorm => None, + NormType::TimeGroupNorm => { + let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(gn) + } + }; Ok(Self { causal: cfg.use_causal_conv, conv, + norm, }) } } @@ -316,8 +326,10 @@ impl Module for EncodecConv1d { fn forward(&self, xs: &Tensor) -> Result<Tensor> { // TODO: padding, depending on causal. let xs = self.conv.forward(xs)?; - // If we add support for NormType "time_group_norm", we should add some normalization here. - Ok(xs) + match &self.norm { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } } } |