summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-07 22:39:59 +0100
committerGitHub <noreply@github.com>2023-11-07 22:39:59 +0100
commit7920b45c8ac737b67e23f04297f6bd7e4860f373 (patch)
tree1bfc4e3cf8bc84f7029143a61f8dd91b3b27fbdf /candle-examples/examples/musicgen
parentd4a45c936a3acb79f89d5a18bcfa3ef34f11ae45 (diff)
downloadcandle-7920b45c8ac737b67e23f04297f6bd7e4860f373.tar.gz
candle-7920b45c8ac737b67e23f04297f6bd7e4860f373.tar.bz2
candle-7920b45c8ac737b67e23f04297f6bd7e4860f373.zip
Support for timegroupnorm in encodec. (#1291)
Diffstat (limited to 'candle-examples/examples/musicgen')
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs18
1 files changed, 15 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 095c90a9..60149e45 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -8,6 +8,7 @@ use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
+ TimeGroupNorm,
None,
}
@@ -268,6 +269,7 @@ impl Module for EncodecConvTranspose1d {
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
+ norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
@@ -292,7 +294,7 @@ impl EncodecConv1d {
},
vb.pp("conv"),
)?,
- NormType::None => conv1d(
+ NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
@@ -305,9 +307,17 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
+ let norm = match cfg.norm_type {
+ NormType::None | NormType::WeightNorm => None,
+ NormType::TimeGroupNorm => {
+ let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+ Some(gn)
+ }
+ };
Ok(Self {
causal: cfg.use_causal_conv,
conv,
+ norm,
})
}
}
@@ -316,8 +326,10 @@ impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
- // If we add support for NormType "time_group_norm", we should add some normalization here.
- Ok(xs)
+ match &self.norm {
+ None => Ok(xs),
+ Some(norm) => xs.apply(norm),
+ }
}
}