summaryrefslogtreecommitdiff
path: root/candle-examples/examples/parler-tts/decode.py
blob: 8942d32eeb474803d26ad432c159339c1fbccf37 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torchaudio
from safetensors.torch import load_file
from parler_tts import DACModel

tensors = load_file("out.safetensors")
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
print(dac_model.model)
output_ids = tensors["codes"][None, None]
print(output_ids, "\n", output_ids.shape)
batch_size = 1
with torch.no_grad():
    output_values = []
    for sample_id in range(batch_size):
        sample = output_ids[:, sample_id]
        sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
        if sample_mask.sum() > 0:
            sample = sample[:, :, sample_mask]
            sample = dac_model.decode(sample[None, ...], [None]).audio_values
            output_values.append(sample.transpose(0, 2))
        else:
            output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
    output_lengths = [audio.shape[0] for audio in output_values]
    pcm = (
        torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
        .squeeze(-1)
        .squeeze(-1)
    )
print(pcm.shape, pcm.dtype)
torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)