summaryrefslogtreecommitdiff
path: root/candle-examples/examples/parler-tts/decode.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-18 19:42:08 +0100
committerGitHub <noreply@github.com>2024-08-18 20:42:08 +0200
commit58197e189657b6587a254882abdb232e83e86848 (patch)
tree01dbed067341d47e933b821a1b33100524611a50 /candle-examples/examples/parler-tts/decode.py
parent736d8eb7521dd48e777827848f2b9ed8a7473571 (diff)
downloadcandle-58197e189657b6587a254882abdb232e83e86848.tar.gz
candle-58197e189657b6587a254882abdb232e83e86848.tar.bz2
candle-58197e189657b6587a254882abdb232e83e86848.zip
parler-tts support (#2431)
* Start sketching parler-tts support. * Implement the attention. * Add the example code. * Fix the example. * Add the description + t5 encode it. * More of the parler forward pass. * Fix the positional embeddings. * Support random sampling in generation. * Handle EOS. * Add the python decoder. * Proper causality mask.
Diffstat (limited to 'candle-examples/examples/parler-tts/decode.py')
-rw-r--r--candle-examples/examples/parler-tts/decode.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-examples/examples/parler-tts/decode.py b/candle-examples/examples/parler-tts/decode.py
new file mode 100644
index 00000000..b79ebda1
--- /dev/null
+++ b/candle-examples/examples/parler-tts/decode.py
@@ -0,0 +1,29 @@
+import torch
+import torchaudio
+from safetensors.torch import load_file
+from parler_tts import DACModel
+
+tensors = load_file("out.safetensors")
+dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
+output_ids = tensors["codes"][None, None]
+print(output_ids, "\n", output_ids.shape)
+batch_size = 1
+with torch.no_grad():
+ output_values = []
+ for sample_id in range(batch_size):
+ sample = output_ids[:, sample_id]
+ sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
+ if sample_mask.sum() > 0:
+ sample = sample[:, :, sample_mask]
+ sample = dac_model.decode(sample[None, ...], [None]).audio_values
+ output_values.append(sample.transpose(0, 2))
+ else:
+ output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
+ output_lengths = [audio.shape[0] for audio in output_values]
+ pcm = (
+ torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
+ .squeeze(-1)
+ .squeeze(-1)
+ )
+print(pcm.shape, pcm.dtype)
+torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)