
Commit e0d43d9

Merge pull request #2120 from pradnya-git-dev/MSTTS
[WIP] Zero-Shot Multi-Speaker Tacotron2
2 parents 24aaa19 + fc892ac

File tree

13 files changed: +2582 −64 lines changed

recipes/LibriTTS/README.md

Lines changed: 23 additions & 2 deletions
@@ -6,6 +6,25 @@ The LibriTTS dataset is available here: https://www.openslr.org/60/, https://www

The `libritts_prepare.py` file automatically downloads the dataset if it is not already present and lets you specify which subsets to download.

# Zero-Shot Multi-Speaker Tacotron2
The subfolder "TTS/mstacotron2" contains the recipe for training a zero-shot multi-speaker version of the [Tacotron2](https://arxiv.org/abs/1712.05884) model.
To run this recipe, go into the `TTS/mstacotron2` folder and run:

```bash
python train.py hparams/train.yaml --data_folder=/path/to/libritts_data --device=cuda:0 --max_grad_norm=1.0
```

Please ensure that you use absolute paths when specifying the data folder.

Training time on an NVIDIA A100 GPU using the LibriTTS train-clean-100 and train-clean-360 subsets: ~2 hours 54 minutes per epoch.

The training logs are available [here](https://www.dropbox.com/sh/ti2vk7sce8f9fgd/AABcDGWCrBvLX_ZQs76mlJRYa?dl=0).

The pre-trained model with an easy-inference interface is available on [HuggingFace](https://huggingface.co/speechbrain/tts-mstacotron2-libritts).

**Please note**: The current model captures speaker identities effectively. However, the synthesized speech quality exhibits some metallic characteristics and may include artifacts such as overly long pauses. We are actively working to enhance the model and will release updates as soon as improvements are achieved. We warmly welcome contributions from the community to make the model even better!
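Putting the pretrained checkpoint together with a vocoder, end-to-end voice cloning can be sketched as follows. This is an illustrative sketch only: it assumes the `MSTacotron2` and `HIFIGAN` classes from `speechbrain.pretrained` and the `speechbrain/tts-hifigan-libritts-22050Hz` vocoder checkpoint, and both models are downloaded from HuggingFace on first call.

```python
def clone_voice_to_file(text, reference_wav, out_path="cloned.wav"):
    """Synthesize `text` in the voice of the speaker in `reference_wav`.

    Imports are deferred so the heavy models are only loaded when the
    function is actually called; both checkpoints are fetched from
    HuggingFace on first use.
    """
    import torchaudio
    from speechbrain.pretrained import HIFIGAN, MSTacotron2

    # Zero-shot multi-speaker Tacotron2: text + reference speech -> mel-spectrogram
    tts = MSTacotron2.from_hparams(
        source="speechbrain/tts-mstacotron2-libritts", savedir="tmp_tts"
    )
    # HiFi-GAN vocoder: mel-spectrogram -> waveform
    vocoder = HIFIGAN.from_hparams(
        source="speechbrain/tts-hifigan-libritts-22050Hz", savedir="tmp_vocoder"
    )

    mel_outputs, mel_lengths, alignments = tts.clone_voice(text, reference_wav)
    waveforms = vocoder.decode_batch(mel_outputs)
    torchaudio.save(out_path, waveforms.squeeze(1), 22050)
    return out_path
```

A call such as `clone_voice_to_file("Hello world.", "reference.wav")` would write a 22.05 kHz waveform to `cloned.wav`; `reference.wav` here is a hypothetical path to a short clip of the target speaker.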
# HiFi GAN (Vocoder)
The subfolder "vocoder/hifi_gan/" contains the [HiFi GAN vocoder](https://arxiv.org/pdf/2010.05646.pdf).
The vocoder is a neural network that converts a spectrogram into a waveform (it can be used on top of Tacotron2).

@@ -14,11 +33,13 @@ We suggest using `tensorboard_logger` by setting `use_tensorboard: True` in the

To run this recipe, go into the `vocoder/hifigan/` folder and run:

```bash
python train.py hparams/train.yaml --data_folder=/path/to/LibriTTS
```

The recipe will automatically download the LibriTTS dataset and resample it as specified.

Training time on an NVIDIA A100 GPU using the LibriTTS train-clean-100 and train-clean-360 subsets: ~1 hour 50 minutes per epoch.

The training logs and checkpoints are available [here](https://www.dropbox.com/sh/gjs1kslxkxz819q/AABPriN4dOoD1qL7NoIyVk0Oa?dl=0).

Lines changed: 116 additions & 0 deletions
```python
import json
import logging
import os
import pickle

import torchaudio
from speechbrain.pretrained import EncoderClassifier, MelSpectrogramEncoder
from tqdm import tqdm

logger = logging.getLogger(__name__)


def compute_speaker_embeddings(
    input_filepaths,
    output_file_paths,
    data_folder,
    spk_emb_encoder_path,
    spk_emb_sr,
    mel_spec_params,
    device,
):
    """This function processes JSON data manifests to compute speaker embeddings.

    Arguments
    ---------
    input_filepaths : list
        A list of paths to the JSON files to be processed
    output_file_paths : list
        A list of paths to the output pickle files corresponding to the input JSON files
    data_folder : str
        Path to the folder where LibriTTS data is stored
    spk_emb_encoder_path : str
        Path for the speaker encoder
    spk_emb_sr : int
        Sample rate used by the speaker embedding encoder
    mel_spec_params : dict
        Information about mel-spectrogram computation
    device : str
        Device to be used for computation
    """

    # Checks if this phase is already done (if so, skips it)
    if skip(output_file_paths):
        logger.info("Preparation completed in previous run, skipping.")
        return

    # Initializes the speaker encoder
    if mel_spec_params["custom_mel_spec_encoder"]:
        # Uses the custom mel-spectrogram based encoder - for compatibility
        # with future speaker consistency loss work
        spk_emb_encoder = MelSpectrogramEncoder.from_hparams(
            source=spk_emb_encoder_path, run_opts={"device": device}
        )
    else:
        # Uses the speaker encoders available with SpeechBrain
        spk_emb_encoder = EncoderClassifier.from_hparams(
            source=spk_emb_encoder_path, run_opts={"device": device}
        )

    # Processes the data manifest files to create the corresponding
    # speaker embedding files
    for i in range(len(input_filepaths)):
        logger.info(f"Creating {output_file_paths[i]}.")

        speaker_embeddings = dict()  # Holds the computed speaker embeddings

        with open(input_filepaths[i]) as json_file:
            json_data = json.load(json_file)

        # Processes all utterances in the data manifest file
        for utt_id, utt_data in tqdm(json_data.items()):
            utt_wav_path = utt_data["wav"]
            utt_wav_path = utt_wav_path.replace("{data_root}", data_folder)

            # Loads the waveform and resamples it if required
            signal, sig_sr = torchaudio.load(utt_wav_path)
            if sig_sr != spk_emb_sr:
                signal = torchaudio.functional.resample(
                    signal, sig_sr, spk_emb_sr
                )
            signal = signal.to(device)

            # Computes the speaker embedding
            if mel_spec_params["custom_mel_spec_encoder"]:
                spk_emb = spk_emb_encoder.encode_waveform(signal)
            else:
                spk_emb = spk_emb_encoder.encode_batch(signal)

            spk_emb = spk_emb.squeeze().detach()
            speaker_embeddings[utt_id] = spk_emb.cpu()

        # Stores the speaker embeddings at the destination
        with open(output_file_paths[i], "wb") as output_file:
            pickle.dump(
                speaker_embeddings,
                output_file,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

        logger.info(f"Created {output_file_paths[i]}.")


def skip(filepaths):
    """Detects if the data preparation has already been done.
    If the preparation has been done, we can skip it.

    Returns
    -------
    bool
        if True, the preparation phase can be skipped.
        if False, it must be done.
    """
    for filepath in filepaths:
        if not os.path.isfile(filepath):
            return False
    return True
```
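The files written by `compute_speaker_embeddings` are plain pickles mapping utterance IDs to embedding tensors. A minimal sketch of that round trip, with short stand-in lists in place of real embedding tensors, and with a hypothetical file name and utterance IDs:

```python
import os
import pickle
import tempfile

# Stand-ins for speaker embedding tensors (real entries are torch tensors,
# e.g. 192-dimensional vectors; shortened lists are used here for illustration)
speaker_embeddings = {
    "1034_121119_000001_000001": [0.12, -0.48, 0.33],
    "1034_121119_000002_000004": [0.10, -0.51, 0.29],
}

with tempfile.TemporaryDirectory() as tmpdir:
    emb_file = os.path.join(tmpdir, "train_speaker_embeddings.pickle")

    # Write the utterance-id -> embedding mapping, as the preparation step does
    with open(emb_file, "wb") as f:
        pickle.dump(speaker_embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

    # A training loop can later look up an utterance's embedding by its ID
    with open(emb_file, "rb") as f:
        loaded = pickle.load(f)

print(loaded == speaker_embeddings)  # True
```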
