diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-18 19:42:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-18 20:42:08 +0200 |
commit | 58197e189657b6587a254882abdb232e83e86848 (patch) | |
tree | 01dbed067341d47e933b821a1b33100524611a50 /candle-examples/examples/parler-tts/decode.py | |
parent | 736d8eb7521dd48e777827848f2b9ed8a7473571 (diff) | |
download | candle-58197e189657b6587a254882abdb232e83e86848.tar.gz candle-58197e189657b6587a254882abdb232e83e86848.tar.bz2 candle-58197e189657b6587a254882abdb232e83e86848.zip |
parler-tts support (#2431)
* Start sketching parler-tts support.
* Implement the attention.
* Add the example code.
* Fix the example.
* Add the description + t5 encode it.
* More of the parler forward pass.
* Fix the positional embeddings.
* Support random sampling in generation.
* Handle EOS.
* Add the python decoder.
* Proper causality mask.
Diffstat (limited to 'candle-examples/examples/parler-tts/decode.py')
-rw-r--r-- | candle-examples/examples/parler-tts/decode.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-examples/examples/parler-tts/decode.py b/candle-examples/examples/parler-tts/decode.py new file mode 100644 index 00000000..b79ebda1 --- /dev/null +++ b/candle-examples/examples/parler-tts/decode.py @@ -0,0 +1,29 @@ +import torch +import torchaudio +from safetensors.torch import load_file +from parler_tts import DACModel + +tensors = load_file("out.safetensors") +dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps") +output_ids = tensors["codes"][None, None] +print(output_ids, "\n", output_ids.shape) +batch_size = 1 +with torch.no_grad(): + output_values = [] + for sample_id in range(batch_size): + sample = output_ids[:, sample_id] + sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0 + if sample_mask.sum() > 0: + sample = sample[:, :, sample_mask] + sample = dac_model.decode(sample[None, ...], [None]).audio_values + output_values.append(sample.transpose(0, 2)) + else: + output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device)) + output_lengths = [audio.shape[0] for audio in output_values] + pcm = ( + torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0) + .squeeze(-1) + .squeeze(-1) + ) +print(pcm.shape, pcm.dtype) +torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100) |