summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/musicgen_model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/musicgen_model.rs')
-rw-r--r--candle-examples/examples/musicgen/musicgen_model.rs39
1 files changed, 32 insertions, 7 deletions
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index c6b52fde..03e96614 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,10 +1,9 @@
-use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
-use candle_transformers::models::t5;
+use candle_transformers::models::{encodec, t5};
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@@ -372,7 +371,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: t5::T5EncoderModel,
- pub audio_encoder: crate::encodec_model::EncodecModel,
+ pub audio_encoder: encodec::Model,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
}
@@ -381,15 +380,42 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig {
musicgen: Config,
t5: t5::Config,
- encodec: crate::encodec_model::Config,
+ encodec: encodec::Config,
}
impl GenConfig {
pub fn small() -> Self {
+ // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
+ let encodec = encodec::Config {
+ audio_channels: 1,
+ chunk_length_s: None,
+ codebook_dim: Some(128),
+ codebook_size: 2048,
+ compress: 2,
+ dilation_growth_rate: 2,
+ hidden_size: 128,
+ kernel_size: 7,
+ last_kernel_size: 7,
+ norm_type: encodec::NormType::WeightNorm,
+ normalize: false,
+ num_filters: 64,
+ num_lstm_layers: 2,
+ num_residual_layers: 1,
+ overlap: None,
+ // This should be Reflect and not Replicate but Reflect does not work yet.
+ pad_mode: encodec::PadMode::Replicate,
+ residual_kernel_size: 3,
+ sampling_rate: 32_000,
+ target_bandwidths: vec![2.2],
+ trim_right_ratio: 1.0,
+ upsampling_ratios: vec![8, 5, 4, 4],
+ use_causal_conv: false,
+ use_conv_shortcut: false,
+ };
Self {
musicgen: Config::musicgen_small(),
t5: t5::Config::musicgen_small(),
- encodec: encodec_model::Config::musicgen_small(),
+ encodec,
}
}
}
@@ -401,8 +427,7 @@ impl MusicgenForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
- let audio_encoder =
- encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
+ let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self {
text_encoder,