diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..dad09d34d --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +ignore = E203, E701, E704, W503 diff --git a/.gitignore b/.gitignore index c0494b5a5..c2007963a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +/.vscode +/data +/saved_models +/test.py + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/ChatTTS/train/__init__.py b/ChatTTS/train/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ChatTTS/train/dataset.py b/ChatTTS/train/dataset.py new file mode 100644 index 000000000..1ab960391 --- /dev/null +++ b/ChatTTS/train/dataset.py @@ -0,0 +1,500 @@ +import os +import json +import tarfile +import io +import logging +import tqdm +import abc +import typing + +import torch.utils.data +import torchaudio +from torchvision.datasets.utils import download_url +import transformers + +from ChatTTS.norm import Normalizer + + +class LazyDataType(typing.TypedDict): + filepath: str + speaker: str + lang: str + text: str + + +class DataType(LazyDataType): + text_input_ids: torch.Tensor # (batch_size, text_len) + text_attention_mask: torch.Tensor # (batch_size, text_len) + waveforms: torch.Tensor # (batch_size, time) + waveform_attention_mask: torch.Tensor # (batch_size, time) + + +class XzListTarKwargsType(typing.TypedDict): + tokenizer: typing.NotRequired[transformers.PreTrainedTokenizer | None] + normalizer: typing.NotRequired[Normalizer | None] + speakers: typing.NotRequired[typing.Iterable[str] | None] + sample_rate: typing.NotRequired[int] + default_speaker: typing.NotRequired[str | None] + default_lang: typing.NotRequired[str | None] + tar_in_memory: typing.NotRequired[bool] + process_ahead: typing.NotRequired[bool] + + +class AudioFolder(torch.utils.data.Dataset, abc.ABC): + def __init__( + self, + root: str | io.TextIOWrapper, + tokenizer: transformers.PreTrainedTokenizer | None = None, + normalizer: Normalizer | None = None, + speakers: typing.Iterable[str] | None = None, + sample_rate: int = 24_000, + default_speaker: str | None = None, + default_lang: str | None = None, + tar_path: str | None = None, + tar_in_memory: bool = False, + process_ahead: bool = False, + ) -> None: + self.root = root + self.sample_rate = sample_rate + self.default_speaker = default_speaker + self.default_lang = default_lang + + self.logger = logging.getLogger(__name__) + self.normalizer = normalizer + self.tokenizer = tokenizer + + # tar -cvf ../Xz.tar * + # tar -xf Xz.tar -C ./Xz + self.tar_path = tar_path + self.tar_file = None + self.tar_io = None + if tar_path is not None: + if tar_in_memory: + with open(tar_path, "rb") as f: + self.tar_io = io.BytesIO(f.read()) + self.tar_file = tarfile.open(fileobj=self.tar_io) + else: + self.tar_file = tarfile.open(tar_path) + + self.lazy_data, self.speakers = self.get_lazy_data(root, speakers) + + self.text_input_ids: dict[int, torch.Tensor] = {} + self.waveforms: dict[int, torch.Tensor] = {} + if process_ahead: + print("Processing data ...") + for n, item in enumerate(tqdm.tqdm(self.lazy_data)): + self.waveforms[n] = self.preprocess_audio(item["filepath"]) + self.text_input_ids[n] = self.preprocess_text( + item["text"], item["lang"] + ) + if self.tar_file is not None: + self.tar_file.close() + if self.tar_io is not None: + self.tar_io.close() + + @abc.abstractmethod + def get_raw_data(self, root: str | io.TextIOWrapper) -> list[dict[str, str]]: ... 
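+    # Subclasses return plain dicts from get_raw_data(), one per utterance, e.g.
+    #   [{"filepath": "speaker_A/1.wav", "speaker": "A", "lang": "ZH", "text": "..."}]
+    # "speaker" and "lang" may be omitted; get_lazy_data() then falls back to
+    # default_speaker / default_lang.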
+ + @staticmethod + @abc.abstractmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: ... + + def __len__(self): + return len(self.lazy_data) + + def __getitem__(self, n: int) -> DataType: + lazy_data = self.lazy_data[n] + if n in self.waveforms: + waveforms = self.waveforms[n] + text_input_ids = self.text_input_ids[n] + else: + waveforms = self.preprocess_audio(lazy_data["filepath"]) + text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"]) + self.waveforms[n] = waveforms + self.text_input_ids[n] = text_input_ids + if len(self.waveforms) == len(self.lazy_data): + if self.tar_file is not None: + self.tar_file.close() + if self.tar_io is not None: + self.tar_io.close() + text_attention_mask = torch.ones( + len(text_input_ids), device=text_input_ids.device + ) + waveform_attention_mask = torch.ones(len(waveforms), device=waveforms.device) + return { + "filepath": lazy_data["filepath"], + "speaker": lazy_data["speaker"], + "lang": lazy_data["lang"], + "text": lazy_data["text"], + "text_input_ids": text_input_ids, + "text_attention_mask": text_attention_mask, + "waveforms": waveforms, + "waveform_attention_mask": waveform_attention_mask, + } + + def get_lazy_data( + self, + root: str | io.TextIOWrapper, + speakers: typing.Iterable[str] | None = None, + ) -> tuple[list[LazyDataType], set[str]]: + if speakers is not None: + new_speakers = set(speakers) + else: + new_speakers = set() + lazy_data = [] + + raw_data = self.get_raw_data(root) + folder_path = os.path.dirname(root) if isinstance(root, str) else "" + for item in raw_data: + if "speaker" not in item: + item["speaker"] = self.default_speaker + if "lang" not in item: + item["lang"] = self.default_lang + + if speakers is not None and item["speaker"] not in speakers: + continue + if speakers is None and item["speaker"] not in new_speakers: + new_speakers.add(item["speaker"]) + if self.tar_file is None and isinstance(root, str): + filepath = os.path.join(folder_path, item["filepath"]) + else: + filepath = item["filepath"] + lazy_data.append( + { + "filepath": filepath, + "speaker": item["speaker"], + "lang": item["lang"].lower(), + "text": item["text"], + } + ) + return lazy_data, new_speakers + + def preprocess_text( + self, + text: str, + lang: str, + do_text_normalization: bool = True, + do_homophone_replacement: bool = True, + ) -> torch.Tensor: + + text = self.normalizer( + text, + do_text_normalization, + do_homophone_replacement, + lang, + ) + + text = f"[Stts][spk_emb]{text}[Ptts]" + # text = f'[Stts][empty_spk]{text}[Ptts]' + + text_token = self.tokenizer(text, return_tensors="pt", add_special_tokens=False) + return text_token["input_ids"].squeeze(0) + + def preprocess_audio(self, filepath: str) -> torch.Tensor: + if self.tar_file is not None: + file = self.tar_file.extractfile(filepath) + waveforms, sample_rate = torchaudio.load(file) + else: + waveforms, sample_rate = torchaudio.load(filepath) + if sample_rate != self.sample_rate: + waveforms = torchaudio.functional.resample( + waveforms, + orig_freq=sample_rate, + new_freq=self.sample_rate, + ) + # (channel, time) + return waveforms.mean(0) # (time,) + + +class JsonFolder(AudioFolder): + """ + In json file, each item is formatted as following example: + `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`. + + filepath is relative to the dirname of root json file. 
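+
+    The file itself is a JSON array of such items, e.g.
+    `[{"filepath": "a/1.wav", "speaker": "John", "lang": "ZH", "text": "Hello"},
+    {"filepath": "a/2.wav", "speaker": "John", "lang": "ZH", "text": "Bye"}]`.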
+ """ + + def get_raw_data(self, root: str | io.TextIOWrapper) -> list[dict[str, str]]: + root = open(root, "r", encoding="utf-8") if isinstance(root, str) else root + raw_data = json.load(root) + root.close() + return raw_data + + @staticmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: + save_data = [item.copy() for item in lazy_data] + for item in save_data: + item["filepath"] = os.path.relpath(item["filepath"], rel_path) + with open(save_path, "w", encoding="utf-8") as f: + json.dump(save_data, f, ensure_ascii=False, indent=4) + + +class ListFolder(AudioFolder): + """ + In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator. + `path/to/file.wav|John|ZH|Hello`. + + filepath is relative to the dirname of root list file. + """ + + def get_raw_data(self, root: str | io.TextIOWrapper) -> list[dict[str, str]]: + raw_data = [] + root = open(root, "r", encoding="utf-8") if isinstance(root, str) else root + for line in root.readlines(): + line = line.strip().removesuffix("\n") + if len(line) == 0: + continue + filepath, speaker, lang, text = line.split(sep="|", maxsplit=3) + raw_data.append( + { + "text": text, + "filepath": filepath, + "speaker": speaker, + "lang": lang, + } + ) + root.close() + return raw_data + + @staticmethod + def save_config( + save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./" + ) -> None: + save_data = [item.copy() for item in lazy_data] + for item in save_data: + item["filepath"] = os.path.relpath(item["filepath"], rel_path) + with open(save_path, "w", encoding="utf-8") as f: + for item in save_data: + f.write( + f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n" + ) + + +class XzListTar(ListFolder): + """ + from torchvision.datasets.utils import download_url + download_url('https://drive.google.com/file/d/1vv73kAHiKb4KiL_oIH4DOWzUoaTeKzt_', './', 'Xz.tar', md5='47683c253d10250d9c32c964118c2b7c') + """ # noqa: E501 + + url = "https://drive.google.com/file/d/1vv73kAHiKb4KiL_oIH4DOWzUoaTeKzt_" + md5 = "47683c253d10250d9c32c964118c2b7c" + + def __init__( + self, + *args, + root: str | io.TextIOWrapper, + tar_path: str | None = None, + **kwargs: typing.Unpack[XzListTarKwargsType], + ): + if isinstance(root, str): + # make sure root is a list file + if not root.endswith(".list"): # folder case + if os.path.isfile(root): + raise FileExistsError(f"{root} is a file!") + elif not os.path.exists(root): + os.makedirs(root) + root = os.path.join(root, "all.list") + # make sure tar_path is a tar file + if tar_path is None: + dirname = os.path.dirname(root) + assert dirname + tar_path = os.path.join(dirname, "Xz.tar") + elif not tar_path.endswith(".tar"): # folder case + if os.path.isfile(tar_path): + raise FileExistsError(f"{tar_path} is a file!") + elif not os.path.exists(tar_path): + os.makedirs(tar_path) + tar_path = os.path.join(tar_path, "Xz.tar") + else: + assert tar_path is not None + # download tar file if not exists + if not os.path.isfile(tar_path): + dirname, basename = os.path.split(tar_path) + if not os.path.isdir(dirname): + os.makedirs(dirname) + download_url(self.url, dirname, basename, md5=self.md5) + self.prepare_all_list() # prepare all.list + if isinstance(root, str) and not os.path.isfile(root): + # if root is all.list, make sure it is prepared + if not root.endswith("all.list"): + with tarfile.open(tar_path) as tar_file: + root_str = tar_file.extractfile(root).read().decode("utf-8") + root = io.StringIO(root_str) + else: + 
self.prepare_all_list(tar_path=tar_path) # prepare all.list + + super().__init__(root, *args, tar_path=tar_path, **kwargs) + + @staticmethod + def prepare_all_list( + tar_path: str, + save_folder: str | None = None, + langs: list[str] = ["zh", "en"], + ) -> None: + if save_folder is None: + save_folder = os.path.dirname(tar_path) + if os.path.isfile(save_folder): + raise FileExistsError(f"{save_folder} already exists as a file!") + elif not os.path.exists(save_folder): + os.makedirs(save_folder) + lazy_data = [] + + with tarfile.open(tar_path) as tar_file: + for member in tar_file.getmembers(): + if not member.isfile(): + continue + if member.name.endswith(".list"): + print(member.name) + root_io = io.TextIOWrapper(tar_file.extractfile(member)) + lazy_data += ListFolder(root_io).lazy_data + if member.name.endswith(".json"): + print(member.name) + root_io = io.TextIOWrapper(tar_file.extractfile(member)) + lazy_data += JsonFolder(root_io).lazy_data + if langs is not None: + lazy_data = [item for item in lazy_data if item["lang"] in langs] + ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data) + JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data) + print(f"all.list and all.json are saved to {save_folder}") + + +class XzListFolder(ListFolder): + """ + [Xz乔希](https://space.bilibili.com/5859321) + + Only look at the basename of filepath in list file. Previous folder paths are ignored. + Files are organized as `[list basename]/[file basename]` + + Example tree structure: + + [folder] + ├── speaker_A + │ ├── 1.wav + │ └── 2.wav + ├── speaker_A.list + ├── speaker_B + │ ├── 1.wav + │ └── 2.wav + └── speaker_B.list + """ + + def get_raw_data(self, root: str) -> list[dict[str, str]]: + raw_data = super().get_raw_data(root) + for item in raw_data: + item["filepath"] = os.path.join( + os.path.basename(root).removesuffix(".list"), + os.path.basename(item["filepath"]), + ) + return raw_data + + +class AudioCollator: + def __init__(self, text_pad: int = 0, audio_pad: int = 0): + self.text_pad = text_pad + self.audio_pad = audio_pad + + def __call__(self, batch: list[DataType]): + batch = [x for x in batch if x is not None] + + audio_maxlen = max(len(item["waveforms"]) for item in batch) + text_maxlen = max(len(item["text_input_ids"]) for item in batch) + + filepath = [] + speaker = [] + lang = [] + text = [] + text_input_ids = [] + text_attention_mask = [] + waveforms = [] + waveform_attention_mask = [] + + for x in batch: + filepath.append(x["filepath"]) + speaker.append(x["speaker"]) + lang.append(x["lang"]) + text.append(x["text"]) + text_input_ids.append( + torch.nn.functional.pad( + x["text_input_ids"], + (text_maxlen - len(x["text_attention_mask"]), 0), + value=self.text_pad, + ) + ) + text_attention_mask.append( + torch.nn.functional.pad( + x["text_attention_mask"], + (text_maxlen - len(x["text_attention_mask"]), 0), + value=0, + ) + ) + waveforms.append( + torch.nn.functional.pad( + x["waveforms"], + (0, audio_maxlen - len(x["waveform_attention_mask"])), + value=self.audio_pad, + ) + ) + waveform_attention_mask.append( + torch.nn.functional.pad( + x["waveform_attention_mask"], + (0, audio_maxlen - len(x["waveform_attention_mask"])), + value=0, + ) + ) + return { + "filepath": filepath, + "speaker": speaker, + "lang": lang, + "text": text, + "text_input_ids": torch.stack(text_input_ids), + "text_attention_mask": torch.stack(text_attention_mask), + "waveforms": torch.stack(waveforms), + "waveform_attention_mask": torch.stack(waveform_attention_mask), + } + + 
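+# Typical usage of the folder datasets with the collator above, mirroring
+# examples/finetune/finetune.py and ChatTTS/train/model.py; `chat` is assumed to be
+# a loaded ChatTTS.Chat instance and the .list path is a placeholder:
+#
+#     dataset = ListFolder(
+#         "dummy_data/xz_list_style/speaker_A.list",
+#         tokenizer=chat.tokenizer._tokenizer,
+#         normalizer=chat.normalizer,
+#     )
+#     loader = torch.utils.data.DataLoader(
+#         dataset,
+#         batch_size=4,
+#         shuffle=True,
+#         collate_fn=AudioCollator(),
+#     )
+#     batch = next(iter(loader))
+#     # batch["text_input_ids"]: (batch_size, text_len), left-padded with text_pad
+#     # batch["waveforms"]: (batch_size, time), right-padded with audio_pad
+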
+def formalize_xz_list(src_folder: str): + for root, _, files in os.walk(src_folder): + for file in files: + if file.endswith(".list"): + filepath = os.path.join(root, file) + print(filepath) + lazy_data = XzListFolder(filepath).lazy_data + XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder) + + +def prepare_all_list( + src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"] +) -> None: + if save_folder is None: + save_folder = src_folder + if os.path.isfile(save_folder): + raise FileExistsError(f"{save_folder} already exists as a file!") + elif not os.path.exists(save_folder): + os.makedirs(save_folder) + lazy_data = [] + same_folder = os.path.samefile(src_folder, save_folder) + for root, _, files in os.walk(src_folder): + for file in files: + filepath = os.path.join(root, file) + if same_folder and file in ("all.list", "all.json"): + continue + if file.endswith(".list"): + print(filepath) + lazy_data += ListFolder(filepath).lazy_data + elif file.endswith(".json"): + print(filepath) + lazy_data += JsonFolder(filepath).lazy_data + if langs is not None: + lazy_data = [item for item in lazy_data if item["lang"] in langs] + ListFolder.save_config( + os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder + ) + JsonFolder.save_config( + os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder + ) + print(f"all.list and all.json are saved to {save_folder}") diff --git a/ChatTTS/train/model.py b/ChatTTS/train/model.py new file mode 100644 index 000000000..98b4285a0 --- /dev/null +++ b/ChatTTS/train/model.py @@ -0,0 +1,357 @@ +from enum import StrEnum + +import torch +import torch.nn.functional +import torch.utils.data + +import ChatTTS +from ChatTTS.utils.ansi import ansi, get_ansi_len, output_iter +from ChatTTS.utils.log import MetricLogger + +from .dataset import AudioFolder, AudioCollator +from .utils import ( + get_mel_specs, + get_mel_attention_mask, + get_dvae_mel_specs, + get_hidden_states_and_labels, +) + + +class TrainModule(StrEnum): + GPT_ALL = "gpt_all" # GPT + SPEAKER + DECODER + GPT_SPEAKER = "gpt_speaker" # GPT + SPEAKER + + GPT = "gpt" + DECODER = "decoder" + SPEAKER = "speaker" + + DVAE = "dvae" + DVAE_ENCODER = "dvae_encoder" + DVAE_DECODER = "dvae_decoder" + + +def train_autoencoder( + chat: ChatTTS.Chat, + dataset: AudioFolder, + train_module: TrainModule = TrainModule.DVAE, + batch_size: int = 10, + epochs: int = 10, + lr: float = 1e-3, + grad_norm_clip: float = 1.0, + validate: bool = False, +): + chat.dvae.eval().requires_grad_(False) + if not validate: + match train_module: + case TrainModule.DVAE: + train_params = list(chat.dvae.parameters()) + case TrainModule.DVAE_ENCODER: + train_params = [] + train_params += list(chat.dvae.downsample_conv.parameters()) + train_params += list(chat.dvae.encoder.parameters()) + train_params += list(chat.dvae.vq_layer.parameters()) + case TrainModule.DVAE_DECODER: + train_params = [] + train_params += list(chat.dvae.decoder.parameters()) + train_params += list(chat.dvae.out_conv.parameters()) + optimizer = torch.optim.AdamW(train_params, lr=lr, betas=[0.8, 0.99], eps=1e-6) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, epochs, 1e-7 + ) + # lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999999) + + def activate_params(): + match train_module: + case TrainModule.DVAE: + chat.dvae.train().requires_grad_() + case TrainModule.DVAE_ENCODER: + chat.dvae.downsample_conv.train().requires_grad_() + 
chat.dvae.encoder.train().requires_grad_() + chat.dvae.vq_layer.train().requires_grad_() + case TrainModule.DVAE_DECODER: + chat.dvae.decoder.train().requires_grad_() + chat.dvae.out_conv.train().requires_grad_() + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=AudioCollator(), + # num_workers=4, + ) + logger = MetricLogger() + logger.create_meters(loss=None) + if not validate: + train_autoencoder(chat=chat, dataset=dataset, validate=True) + for _epoch in range(1 if validate else epochs): + if not validate: + activate_params() + _epoch += 1 + logger.reset() + if validate: + header: str = "{blue_light}{0}{reset}".format("AutoEncoder", **ansi) + header = header.ljust(max(len("AutoEncoder"), 30) + get_ansi_len(header)) + else: + header: str = "{blue_light}{0}: {1}{reset}".format( + "Epoch", output_iter(_epoch, epochs), **ansi + ) + header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header)) + iterator = logger.log_every(loader, header=header, tqdm_header="Batch") + for batch in iterator: + waveforms: torch.Tensor = batch["waveforms"] # (batch_size, time) + waveform_attention_mask: torch.Tensor = batch[ + "waveform_attention_mask" + ] # (batch_size, time) + + waveforms = waveforms.to(chat.device, non_blocking=True) + waveform_attention_mask = waveform_attention_mask.to( + chat.device, non_blocking=True + ) + + mel_specs = get_mel_specs(chat, waveforms) # (batch_size, 100, mel_len) + mel_attention_mask = get_mel_attention_mask( + waveform_attention_mask, mel_len=mel_specs.size(2) + ) # (batch_size, mel_len) + mel_specs = mel_specs * mel_attention_mask.unsqueeze(1) # clip + + dvae_mel_specs = get_dvae_mel_specs( + chat, mel_specs, mel_attention_mask + ) # (batch_size, 100, mel_len) + dvae_mel_specs = dvae_mel_specs * mel_attention_mask.unsqueeze(1) # clip + + loss = torch.nn.functional.mse_loss(dvae_mel_specs, mel_specs) + logger.meters["loss"].update(loss.item(), n=len(waveform_attention_mask)) + if not validate: + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(train_params, grad_norm_clip) + optimizer.step() + if not validate: + lr_scheduler.step() + train_autoencoder(chat=chat, dataset=dataset, validate=True) + if not validate: + optimizer.zero_grad() + + +def train_gpt( + chat: ChatTTS.Chat, + dataset: AudioFolder, + train_module: TrainModule = TrainModule.GPT_ALL, + batch_size: int = 10, + epochs: int = 10, + grad_norm_clip: float = 1.0, + speaker_embeds: dict[str, torch.Tensor] = {}, + train_text: bool = False, + validate: bool = False, +) -> dict[str, torch.Tensor]: + for speaker in dataset.speakers: + if speaker not in speaker_embeds: + speaker_embeds[speaker] = chat.speaker._sample_random().to( + device=chat.device + ) + + chat.dvae.eval().requires_grad_(False) + chat.gpt.eval().requires_grad_(False) + chat.decoder.eval().requires_grad_(False) + + if not validate: + train_speaker = train_module in [ + TrainModule.GPT_ALL, + TrainModule.GPT_SPEAKER, + TrainModule.SPEAKER, + ] + match train_module: + case TrainModule.GPT_ALL: + train_params = [] + train_params += list(speaker_embeds.values()) + train_params += list(chat.gpt.parameters()) + train_params += list(chat.decoder.parameters()) + optimizer = torch.optim.Adam( + chat.gpt.parameters(), lr=1e-5, weight_decay=0, betas=[0.9, 0.95] + ) + optimizer.add_param_group( + { + "params": chat.decoder.parameters(), + "lr": 1e-5, + "weight_decay": 0, + "betas": [0.9, 0.95], + } + ) + optimizer.add_param_group( + { + "params": speaker_embeds.values(), + 
"lr": 1e-2, + "weight_decay": 0, + "betas": [0.9, 0.95], + } + ) + case TrainModule.GPT_SPEAKER: + train_params = [] + train_params += list(speaker_embeds.values()) + train_params += list(chat.gpt.parameters()) + optimizer = torch.optim.Adam( + chat.gpt.parameters(), lr=1e-5, weight_decay=0, betas=[0.9, 0.95] + ) + optimizer.add_param_group( + { + "params": speaker_embeds.values(), + "lr": 1e-2, + "weight_decay": 0, + "betas": [0.9, 0.95], + } + ) + case TrainModule.GPT: + train_params = list(chat.gpt.parameters()) + case TrainModule.DECODER: + train_params = list(chat.decoder.parameters()) + case TrainModule.SPEAKER: + train_params = list(speaker_embeds.values()) + optimizer = torch.optim.Adam( + train_params, lr=1e-2, weight_decay=0, betas=[0.9, 0.95] + ) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, epochs, 1e-7 + ) + # lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=functools.partial()) + + def activate_params(): + if train_speaker: + for speaker_embed in speaker_embeds.values(): + speaker_embed.requires_grad_(True) + match train_module: + case TrainModule.GPT_ALL: + chat.gpt.train().requires_grad_() + chat.decoder.train().requires_grad_() + case TrainModule.GPT_SPEAKER: + chat.gpt.train().requires_grad_() + case TrainModule.GPT: + chat.gpt.train().requires_grad_() + case TrainModule.DECODER: + chat.decoder.train().requires_grad_() + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=AudioCollator(), + # num_workers=4, + ) + logger = MetricLogger() + logger.create_meters(audio_loss=None, mse_loss=None) + if validate or train_text: + logger.create_meters(text_loss=None) + if not validate: + train_gpt( + chat=chat, dataset=dataset, speaker_embeds=speaker_embeds, validate=True + ) + for _epoch in range(1 if validate else epochs): + if not validate: + activate_params() + _epoch += 1 + logger.reset() + if validate: + header: str = "{blue_light}{0}{reset}".format("GPT", **ansi) + header = header.ljust(max(len("GPT"), 30) + get_ansi_len(header)) + else: + header: str = "{blue_light}{0}: {1}{reset}".format( + "Epoch", output_iter(_epoch, epochs), **ansi + ) + header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header)) + iterator = logger.log_every(loader, header=header, tqdm_header="Batch") + for batch in iterator: + speakers: list[str] = batch["speaker"] # (batch_size,) + text_input_ids: torch.Tensor = batch[ + "text_input_ids" + ] # (batch_size, text_len) + text_attention_mask: torch.Tensor = batch[ + "text_attention_mask" + ] # (batch_size, text_len) + waveforms: torch.Tensor = batch["waveforms"] # (batch_size, time) + waveform_attention_mask: torch.Tensor = batch[ + "waveform_attention_mask" + ] # (batch_size, time) + + text_input_ids = text_input_ids.to(chat.device, non_blocking=True) + text_attention_mask = text_attention_mask.to(chat.device, non_blocking=True) + waveforms = waveforms.to(chat.device, non_blocking=True) + waveform_attention_mask = waveform_attention_mask.to( + chat.device, non_blocking=True + ) + + mel_specs = get_mel_specs(chat, waveforms) # (batch_size, 100, mel_len) + mel_attention_mask = get_mel_attention_mask( + waveform_attention_mask, mel_len=mel_specs.size(2) + ) # (batch_size, mel_len) + mel_specs = mel_specs * mel_attention_mask.unsqueeze(1) # clip + + results = get_hidden_states_and_labels( + chat=chat, + mel_specs=mel_specs, + mel_attention_mask=mel_attention_mask, + text_input_ids=text_input_ids, + text_attention_mask=text_attention_mask, + 
speakers=speakers, + speaker_embeds=speaker_embeds, + ) + hidden_states = results["hidden_states"] + labels = results["labels"] + + text_len = text_input_ids.size(1) + audio_hidden_states = hidden_states[ + :, text_len - 1 : -1 + ] # (batch_size, mel_len+1, 768) + audio_labels = labels[:, text_len:] # (batch_size, mel_len+1) + + audio_logits = torch.stack( + [ + chat.gpt.head_code[i](audio_hidden_states) + for i in range(chat.gpt.num_vq) + ], + dim=2, + ) # (batch_size, mel_len+1, num_vq, num_class_audio) + audio_loss: torch.Tensor = torch.nn.functional.cross_entropy( + audio_logits.flatten(0, 2), audio_labels.flatten(0, 2) + ) + loss: torch.Tensor = audio_loss + if validate or train_text: + text_hidden_states = hidden_states[ + :, : text_len - 1 + ] # (batch_size, text_len-1, 768) + text_labels = labels[:, 1:text_len, 0] # (batch_size, text_len-1) + + text_logits: torch.Tensor = chat.gpt.head_text( + text_hidden_states + ) # (batch_size, text_len-1, num_class_text) + text_loss: torch.Tensor = torch.nn.functional.cross_entropy( + text_logits.flatten(0, 1), text_labels.flatten(0, 1) + ) + loss = loss + text_loss + logger.meters["text_loss"].update(text_loss.item(), n=batch_size) + + decoder_mel_specs = chat.decoder( + audio_hidden_states[:, :-1].transpose(1, 2) + ) + decoder_mel_specs = decoder_mel_specs * mel_attention_mask.unsqueeze( + 1 + ) # clip + mse_loss = torch.nn.functional.mse_loss( + decoder_mel_specs, + mel_specs, + ) + loss = loss + 10 * mse_loss + logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size) + logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size) + + if not validate: + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(train_params, grad_norm_clip) + optimizer.step() + if not validate: + lr_scheduler.step() + train_gpt( + chat=chat, dataset=dataset, speaker_embeds=speaker_embeds, validate=True + ) + if not validate: + optimizer.zero_grad() + return speaker_embeds diff --git a/ChatTTS/train/utils.py b/ChatTTS/train/utils.py new file mode 100644 index 000000000..4d38afb12 --- /dev/null +++ b/ChatTTS/train/utils.py @@ -0,0 +1,231 @@ +from typing import Literal + +import torch +from einops import rearrange +from transformers.trainer_pt_utils import LabelSmoother +from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ + +import ChatTTS +from ChatTTS.model.dvae import DVAE + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +AUDIO_EOS_TOKEN_ID: int = 0 +AUDIO_PAD_TOKEN_ID: int = AUDIO_EOS_TOKEN_ID +# SPEAKER_TOKEN_ID: int = chat.tokenizer.spk_emb_ids +# AUDIO_EOS_TOKEN_ID: int = tokenizer.convert_tokens_to_ids('[Etts]') + + +def get_mel_specs( + chat: ChatTTS.Chat, + waveforms: torch.Tensor, # (batch_size, time) +) -> torch.Tensor: + mel_specs = chat.dvae.preprocessor_mel(waveforms) # (batch_size, 100, mel_len) + if mel_specs.size(2) % 2 != 0: + mel_specs = torch.cat( + [mel_specs, torch.zeros_like(mel_specs[:, :, :1])], + dim=2, + ) + return mel_specs # (batch_size, 100, mel_len) + + +def get_dvae_mel_specs( + chat: ChatTTS.Chat, + mel_specs: torch.Tensor, # (batch_size, 100, mel_len) + mel_attention_mask: torch.Tensor, # (batch_size, mel_len) +): + audio_attention_mask = get_audio_attention_mask( + mel_attention_mask + ) # (batch_size, mel_len / 2) + audio_latents = dvae_encode( + chat.dvae, mel_specs + ) # (batch_size, audio_dim, mel_len / 2) + audio_latents = audio_latents * audio_attention_mask.unsqueeze(1) # clip + audio_quantized_latents, _ = dvae_quantize( + chat.dvae.vq_layer.quantizer, audio_latents + ) # 
(batch_size, audio_dim, mel_len / 2) + audio_quantized_latents = audio_quantized_latents * audio_attention_mask.unsqueeze( + 1 + ) + dvae_mel_specs = dvae_decode( + chat.dvae, audio_quantized_latents + ) # (batch_size, 100, mel_len) + return dvae_mel_specs # (batch_size, 100, mel_len) + + +def get_mel_attention_mask( + waveform_attention_mask: torch.Tensor, # (batch_size, time) + mel_len: int, +): + batch_size = waveform_attention_mask.size(0) + mel_attention_mask = torch.ones( + (batch_size, mel_len), + device=waveform_attention_mask.device, + ) + indices = waveform_attention_mask.int().sum(dim=1) # (batch_size,) + indices = torch.ceil(indices * mel_len / waveform_attention_mask.size(1)).int() + for i in range(batch_size): + mel_attention_mask[i, indices[i] :] = 0 + return mel_attention_mask # (batch_size, mel_len) + + +def get_audio_attention_mask( + mel_attention_mask: torch.Tensor, # (batch_size, mel_len) +): + audio_attention_mask = mel_attention_mask[:, ::2] # (batch_size, mel_len / 2) + return audio_attention_mask # (batch_size, mel_len / 2) + + +def dvae_encode( + dvae: DVAE, + mel_specs: torch.Tensor, # (batch_size, 100, mel_len) +) -> torch.Tensor: + x: torch.Tensor = dvae.downsample_conv(mel_specs / dvae.coef) + x = dvae.encoder(x) + return x # (batch_size, audio_dim, mel_len / 2) + + +def dvae_quantize( + quantizer: GroupedResidualFSQ, + audio_latents: torch.Tensor, # (batch_size, audio_dim=1024, mel_len / 2) +) -> tuple[torch.Tensor, torch.Tensor]: + # feat shape (batch_size, mel_len / 2, audio_dim) + # ind shape (GFSQ.G, batch_size, mel_len / 2, GFSQ.R) + # num_vq=GFSQ.G*GFSQ.R + feat, ind = quantizer(audio_latents.transpose(1, 2)) + audio_quantized_latents = feat.transpose( + 1, 2 + ) # (batch_size, audio_dim, mel_len / 2) + audio_input_ids = rearrange( + ind, "g b t r ->b t (g r)" + ) # (batch_size, mel_len / 2, num_vq) + return audio_quantized_latents, audio_input_ids + + +def dvae_decode( + dvae: DVAE, + audio_latents: torch.Tensor, # (batch_size, audio_dim, mel_len / 2) +) -> torch.Tensor: + assert audio_latents.size(1) % 2 == 0 + reshaped_audio_latents = ( + audio_latents.view( + ( + audio_latents.size(0), + 2, + audio_latents.size(1) // 2, + audio_latents.size(2), + ), + ) + .permute(0, 2, 3, 1) + .flatten(2) + ) # (batch_size, audio_dim / 2, mel_len) + x: torch.Tensor = dvae.decoder(reshaped_audio_latents) + x = dvae.out_conv(x) + return x * dvae.coef # (batch_size, 100, mel_len) + + +# TODO: a better name +def get_hidden_states_and_labels( + chat: ChatTTS.Chat, + mel_specs: torch.Tensor, # (batch_size, 100, mel_len) + mel_attention_mask: torch.Tensor, # (batch_size, mel_len) + text_input_ids: torch.Tensor, # (batch_size, text_len) + text_attention_mask: torch.Tensor, # (batch_size, text_len) + speakers: list[str], + speaker_embeds: dict[str, torch.Tensor], +) -> dict[Literal["labels"] | Literal["hidden_states"], torch.Tensor]: + audio_attention_mask = get_audio_attention_mask( + mel_attention_mask + ) # (batch_size, mel_len / 2) + audio_latents = dvae_encode( + chat.dvae, mel_specs + ) # (batch_size, audio_dim, mel_len // 2) + audio_latents = audio_latents * audio_attention_mask.unsqueeze(1) # clip + _, dvae_audio_input_ids = dvae_quantize( + chat.dvae.vq_layer.quantizer, audio_latents + ) # (batch_size, mel_len // 2) + dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID + + batch_size = text_attention_mask.size(0) + # add audio eos token + extended_audio_attention_mask = torch.cat( + [ + audio_attention_mask, + torch.zeros( + (batch_size, 1), + 
dtype=audio_attention_mask.dtype, + device=audio_attention_mask.device, + ), + ], + dim=1, + ) # (batch_size, mel_len+1) + extended_audio_input_ids = torch.cat( + [ + dvae_audio_input_ids, + AUDIO_PAD_TOKEN_ID + * torch.ones( + (batch_size, 1, chat.gpt.num_vq), + dtype=dvae_audio_input_ids.dtype, + device=dvae_audio_input_ids.device, + ), + ], + dim=1, + ) # (batch_size, mel_len+1, num_vq) + indices = audio_attention_mask.int().sum(dim=1) # (batch_size,) + for i in range(batch_size): + extended_audio_attention_mask[i, indices[i]] = 1 + extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID + + # combine text and audio + input_ids = torch.cat( # (batch_size, text_len + mel_len + 1, num_vq) + [ + text_input_ids.unsqueeze(-1).repeat( + 1, 1, chat.gpt.num_vq + ), # (batch_size, text_len, num_vq) + extended_audio_input_ids, # (batch_size, mel_len, num_vq) + ], + dim=1, + ) + attention_mask = torch.cat( # (batch_size, text_len + mel_len + 1) + [text_attention_mask, extended_audio_attention_mask], + dim=1, + ) + text_mask = torch.cat( # (batch_size, text_len + mel_len + 1) + [ + torch.ones_like(text_attention_mask, dtype=bool), + torch.zeros_like(extended_audio_attention_mask, dtype=bool), + ], + dim=1, + ) + # set labels + labels = input_ids.clone() # (batch_size, text_len + mel_len + 1, num_vq) + labels[~attention_mask.bool()] = IGNORE_TOKEN_ID + + # (batch_size, text_len + mel_len, 768) + inputs_embeds = chat.embed(input_ids, text_mask).clone() + + for i, speaker in enumerate(speakers): + inputs_embeds[i] = chat.speaker.apply( + emb=inputs_embeds[i].unsqueeze(0), + spk_emb=speaker_embeds[speaker], + input_ids=text_input_ids[i].unsqueeze(0), + spk_emb_ids=chat.tokenizer.spk_emb_ids, + device=chat.device, + ).squeeze(0) + # indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1) + # for i, speaker in enumerate(speakers): + # inputs_embeds[i, indices[i]] = torch.nn.functional.normalize( + # speaker_embeds[speaker].to(dtype=inputs_embeds.dtype), + # p=2.0, + # dim=-1, + # eps=1e-12, + # ).unsqueeze(0) + + # (batch_size, text_len + mel_len) + outputs = chat.gpt.gpt.forward( + inputs_embeds=inputs_embeds, attention_mask=attention_mask + ) + hidden_states = ( + outputs.last_hidden_state + ) # (batch_size, text_len + mel_len + 1, 768) + + return {"labels": labels, "hidden_states": hidden_states} diff --git a/ChatTTS/utils/ansi.py b/ChatTTS/utils/ansi.py new file mode 100644 index 000000000..e5805fe5f --- /dev/null +++ b/ChatTTS/utils/ansi.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +import re +import sys +from contextlib import contextmanager + + +class ANSI: + ansi_color = { + "black": "\033[30m", + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "purple": "\033[35m", + "blue_light": "\033[36m", + "white": "\033[37m", + "reset": "\033[0m", + "upline": "\033[1A", + "clear_line": "\033[2K", + "clear": "\033[2J", + } + ansi_nocolor = { + "black": "", + "red": "", + "green": "", + "yellow": "", + "blue": "", + "purple": "", + "blue_light": "", + "white": "", + "reset": "", + "upline": "\033[1A\033[", + "clear_line": "\033[K", + "clear": "\033[2J", + } + + def __init__(self): + self._dict = ANSI.ansi_color if ("--color" in sys.argv) else ANSI.ansi_nocolor + + def switch(self, color: bool): + self._dict = ANSI.ansi_color if color else ANSI.ansi_nocolor + + def keys(self): + return self._dict.keys() + + def items(self): + return self._dict.items() + + def __getitem__(self, key): + return self._dict[key] + + def __str__(self): + return str(self._dict) + + 
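+    # The instance is meant to be unpacked into str.format, e.g.
+    #   "{blue_light}[ {red}{0}{blue_light} ]{reset}".format("text", **ansi)
+    # which degrades to plain text when colors are disabled.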
def __repr__(self): + return repr(self._dict) + + +ansi = ANSI() + + +def remove_ansi(s: str) -> str: + ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]") + return ansi_escape.sub("", s) + + +def get_ansi_len(s: str) -> int: + return len(s) - len(remove_ansi(s)) + + +def prints(*args: str, indent: int = 0, prefix: str = "", **kwargs): + assert indent >= 0 + new_args = [] + for arg in args: + new_args.append(indent_str(str(arg), indent=indent)) + if len(new_args): + new_args[0] = prefix + str(new_args[0]) + print(*new_args, **kwargs) + + +def output_iter(_iter: int, iteration: int = None, iter_len: int = 4) -> str: + if iteration is None: + pattern = "{blue_light}[ {red}{0}{blue_light} ]{reset}" + return pattern.format(str(_iter).rjust(iter_len), **ansi) + else: + iter_str = str(iteration) + length = len(iter_str) + pattern = ( + "{blue_light}[ {red}{0}{blue_light} " "/ {red}{1}{blue_light} ]{reset}" + ) + return pattern.format(str(_iter).rjust(length), iter_str, **ansi) + + +def indent_str(s_: str, indent: int = 0) -> str: + # modified from torch.nn.modules._addindent + if indent > 0 and s_: + s_ = indent * " " + str(s_[:-1]).replace("\n", "\n" + indent * " ") + s_[-1] + return s_ + + +class IndentRedirect: # TODO: inherit TextIOWrapper? + def __init__(self, buffer: bool = True, indent: int = 0): + self.__console__ = sys.stdout + self.indent = indent + self.__buffer: str = None + if buffer: + self.__buffer = "" + + def write(self, text: str, indent: int = None): + indent = indent if indent is not None else self.indent + text = indent_str(text, indent=indent) + if self.__buffer is None: + self.__console__.write(text) + else: + self.__buffer += text + + def flush(self): + if self.__buffer is not None: + self.__console__.write(self.__buffer) + self.__buffer = "" + self.__console__.flush() + + @contextmanager + def __call__(self): + try: + sys.stdout = self + yield + finally: + sys.stdout = self.__console__ + self.__buffer = "" + + def enable(self): + sys.stdout = self + + def disable(self): + if self.__buffer is not None: + self.__buffer = "" + sys.stdout = self.__console__ + + @property + def buffer(self) -> str: + return self.__buffer + + +redirect = IndentRedirect() diff --git a/ChatTTS/utils/log.py b/ChatTTS/utils/log.py index 1fd9b93d3..c0c83aa29 100644 --- a/ChatTTS/utils/log.py +++ b/ChatTTS/utils/log.py @@ -1,6 +1,422 @@ +#!/usr/bin/env python3 + import logging from pathlib import Path +import statistics +import time +from collections import defaultdict, deque +from tqdm import tqdm as tqdm_class + +from typing import Generator, Iterable, TypeVar + +import torch +import torch.distributed as dist + +from .ansi import ansi, prints, get_ansi_len + +__all__ = ["SmoothedValue", "MetricLogger"] + +MB = 1 << 20 +T = TypeVar("T") + + +class SmoothedValue: + r"""Track a series of values and provide access to smoothed values over a + window or the global series average. + + See Also: + https://github.com/pytorch/vision/blob/main/references/classification/utils.py + + Args: + name (str): Name string. + window_size (int): The :attr:`maxlen` of :class:`~collections.deque`. + fmt (str): The format pattern of ``str(self)``. + + Attributes: + name (str): Name string. + fmt (str): The string pattern. + deque (~collections.deque): The unique data series. + count (int): The amount of data. + total (float): The sum of all data. + + median (float): The median of :attr:`deque`. + avg (float): The avg of :attr:`deque`. 
+ global_avg (float): :math:`\frac{\text{total}}{\text{count}}` + max (float): The max of :attr:`deque`. + min (float): The min of :attr:`deque`. + last_value (float): The last value of :attr:`deque`. + """ + + def __init__( + self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}" + ): + self.name = name + self.deque: deque[float] = deque(maxlen=window_size) + self.count: int = 0 + self.total: float = 0.0 + self.fmt = fmt + + def update(self, value: float, n: int = 1) -> 'SmoothedValue': + r"""Update :attr:`n` pieces of data with same :attr:`value`. + + .. code-block:: python + + self.deque.append(value) + self.total += value * n + self.count += n + + Args: + value (float): the value to update. + n (int): the number of data with same :attr:`value`. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + self.deque.append(value) + self.total += value * n + self.count += n + return self + + def update_list(self, value_list: list[float]) -> 'SmoothedValue': + r"""Update :attr:`value_list`. + + .. code-block:: python + + for value in value_list: + self.deque.append(value) + self.total += value + self.count += len(value_list) + + Args: + value_list (list[float]): the value list to update. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + for value in value_list: + self.deque.append(value) + self.total += value + self.count += len(value_list) + return self + + def reset(self) -> 'SmoothedValue': + r"""Reset ``deque``, ``count`` and ``total`` to be empty. + + Returns: + SmoothedValue: return ``self`` for stream usage. + """ + self.deque = deque(maxlen=self.deque.maxlen) + self.count = 0 + self.total = 0.0 + return self + + def synchronize_between_processes(self): + r""" + Warning: + Does NOT synchronize the deque! + """ + if not (dist.is_available() and dist.is_initialized()): + return + t = torch.tensor([self.count, self.total], + dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = float(t[1]) + + @property + def median(self) -> float: + try: + return statistics.median(self.deque) + except Exception: + return 0.0 + + @property + def avg(self) -> float: + try: + return statistics.mean(self.deque) + except Exception: + return 0.0 + + @property + def global_avg(self) -> float: + try: + return self.total / self.count + except Exception: + return 0.0 + + @property + def max(self) -> float: + try: + return max(self.deque) + except Exception: + return 0.0 + + @property + def min(self) -> float: + try: + return min(self.deque) + except Exception: + return 0.0 + + @property + def last_value(self) -> float: + try: + return self.deque[-1] + except Exception: + return 0.0 + + def __str__(self): + return self.fmt.format( + name=self.name, + count=self.count, + total=self.total, + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + min=self.min, + max=self.max, + last_value=self.last_value, + ) + + def __format__(self, format_spec: str) -> str: + return self.__str__() + + +class MetricLogger: + r""" + See Also: + https://github.com/pytorch/vision/blob/main/references/classification/utils.py + + Args: + delimiter (str): The delimiter to join different meter strings. + Defaults to ``''``. + meter_length (int): The minimum length for each meter. + Defaults to ``20``. + tqdm (bool): Whether to use tqdm to show iteration information. + Defaults to ``env['tqdm']``. + indent (int): The space indent for the entire string. + Defaults to ``0``. 
+ + Attributes: + meters (dict[str, SmoothedValue]): The meter dict. + iter_time (SmoothedValue): Iteration time meter. + data_time (SmoothedValue): Data loading time meter. + memory (SmoothedValue): Memory usage meter. + """ + + def __init__( + self, + delimiter: str = "", + meter_length: int = 20, + tqdm: bool = True, + indent: int = 0, + **kwargs, + ): + self.meters: defaultdict[str, + SmoothedValue] = defaultdict(SmoothedValue) + self.create_meters(**kwargs) + self.delimiter = delimiter + self.meter_length = meter_length + self.tqdm = tqdm + self.indent = indent + + self.iter_time = SmoothedValue() + self.data_time = SmoothedValue() + self.memory = SmoothedValue(fmt="{max:.0f}") + + def create_meters(self, **kwargs: str) -> 'SmoothedValue': + r"""Create meters with specific ``fmt`` in :attr:`self.meters`. + + ``self.meters[meter_name] = SmoothedValue(fmt=fmt)`` + + Args: + **kwargs: ``(meter_name: fmt)`` + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + self.meters[k] = SmoothedValue( + fmt="{global_avg:.3f}" if v is None else v) + return self + + def update(self, n: int = 1, **kwargs: float) -> 'SmoothedValue': + r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`. + + ``self.meters[meter_name].update(float(value), n=n)`` + + Args: + n (int): the number of data with same value. + **kwargs: ``{meter_name: value}``. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + if k not in self.meters: + self.meters[k] = SmoothedValue() + self.meters[k].update(float(v), n=n) + return self + + def update_list(self, **kwargs: list) -> 'SmoothedValue': + r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`. + + ``self.meters[meter_name].update_list(value_list)`` + + Args: + **kwargs: ``{meter_name: value_list}``. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for k, v in kwargs.items(): + self.meters[k].update_list(v) + return self + + def reset(self) -> 'SmoothedValue': + r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`. + + Returns: + MetricLogger: return ``self`` for stream usage. + """ + for meter in self.meters.values(): + meter.reset() + return self + + def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str: + r"""Generate formatted string based on keyword arguments. + + ``key: value`` with max length to be :attr:`self.meter_length`. + + Args: + cut_too_long (bool): Whether to cut too long values to first 5 characters. + Defaults to ``True``. + strip (bool): Whether to strip trailing whitespaces. + Defaults to ``True``. + **kwargs: Keyword arguments to generate string. 
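+
+        Example:
+            With colors disabled (the default without ``--color``),
+            ``self.get_str(loss="0.123")`` yields ``'loss: 0.123'``,
+            left-justified to :attr:`self.meter_length` before stripping.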
+ """ + str_list: list[str] = [] + for k, v in kwargs.items(): + v_str = str(v) + _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi) + max_length = self.meter_length + get_ansi_len(_str) + if cut_too_long: + _str = _str[:max_length] + str_list.append(_str.ljust(max_length)) + _str = self.delimiter.join(str_list) + if strip: + _str = _str.rstrip() + return _str + + def __getattr__(self, attr: str) -> float: + if attr in self.meters: + return self.meters[attr] + if attr in vars(self): # TODO: use hasattr + return vars(self)[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format( + type(self).__name__, attr) + ) + + def __str__(self) -> str: + return self.get_str(**self.meters) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def log_every( + self, + iterable: Iterable[T], + header: str = "", + tqdm: bool = None, + tqdm_header: str = "Iter", + indent: int = None, + verbose: int = 1, + ) -> Generator[T, None, None]: + r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs. + + * Middle Output: + ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}`` + * Final Output + ``{header} str(self) {memory} {iter_time} {data_time} {total_time}`` + + Args: + iterable (~collections.abc.Iterable): The raw iterator. + header (str): The header string for final output. + Defaults to ``''``. + tqdm (bool): Whether to use tqdm to show iteration information. + Defaults to ``self.tqdm``. + tqdm_header (str): The header string for middle output. + Defaults to ``'Iter'``. + indent (int): The space indent for the entire string. + if ``None``, use ``self.indent``. + Defaults to ``None``. + verbose (int): The verbose level of output information. 
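+
+        Example:
+            A typical loop, as used by ``ChatTTS/train/model.py``
+            (``loader``, ``compute_loss`` and ``batch_size`` are placeholders):
+
+            .. code-block:: python
+
+                logger = MetricLogger()
+                logger.create_meters(loss=None)
+                for batch in logger.log_every(loader, header="Epoch", tqdm_header="Batch"):
+                    loss = compute_loss(batch)
+                    logger.meters["loss"].update(loss.item(), n=batch_size)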
+ """ + tqdm = tqdm if tqdm is not None else self.tqdm + indent = indent if indent is not None else self.indent + iterator = iterable + if len(header) != 0: + header = header.ljust(30 + get_ansi_len(header)) + if tqdm: + length = len(str(len(iterable))) + pattern: str = ( + "{tqdm_header}: {blue_light}" + "[ {red}{{n_fmt:>{length}}}{blue_light} " + "/ {red}{{total_fmt}}{blue_light} ]{reset}" + ).format(tqdm_header=tqdm_header, length=length, **ansi) + offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length + pattern = pattern.ljust(30 + offset + get_ansi_len(pattern)) + time_str = self.get_str( + time="{elapsed}<{remaining}", cut_too_long=False) + bar_format = f"{pattern}{{desc}}{time_str}" + iterator = tqdm_class(iterable, leave=False, bar_format=bar_format) + + self.iter_time.reset() + self.data_time.reset() + self.memory.reset() + + end = time.time() + start_time = time.time() + for obj in iterator: + cur_data_time = time.time() - end + self.data_time.update(cur_data_time) + yield obj + cur_iter_time = time.time() - end + self.iter_time.update(cur_iter_time) + if torch.cuda.is_available(): + cur_memory = torch.cuda.max_memory_allocated() / MB + self.memory.update(cur_memory) + if tqdm: + _dict = {k: v for k, v in self.meters.items()} + if verbose > 2 and torch.cuda.is_available(): + _dict.update(memory=f"{cur_memory:.0f} MB") + if verbose > 1: + _dict.update( + iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s" + ) + iterator.set_description_str( + self.get_str(**_dict, strip=False)) + end = time.time() + self.synchronize_between_processes() + total_time = time.time() - start_time + total_time_str = tqdm_class.format_interval(total_time) + + _dict = {k: v for k, v in self.meters.items()} + if verbose > 2 and torch.cuda.is_available(): + _dict.update(memory=f"{str(self.memory)} MB") + if verbose > 1: + _dict.update( + iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s" + ) + _dict.update(time=total_time_str) + prints(self.delimiter.join( + [header, self.get_str(**_dict)]), indent=indent) + class Logger: def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)): diff --git a/examples/finetune/finetune.py b/examples/finetune/finetune.py new file mode 100644 index 000000000..967f6774c --- /dev/null +++ b/examples/finetune/finetune.py @@ -0,0 +1,195 @@ +""" +CUDA_VISIBLE_DEVICES=0 python examples/finetune/finetune.py --color --save_folder ./saved_models --data_path Bekki.list --tar_path data/Xz.tar --batch_size 32 --epochs 10 --train_module dvae +CUDA_VISIBLE_DEVICES=0 python examples/finetune/finetune.py --color --save_folder ./saved_models --data_path Bekki.list --tar_path data/Xz.tar --batch_size 16 --epochs 10 --train_module gpt_speaker + +--gpt_lora --tar_in_memory --process_ahead + +""" # noqa: E501 + +import argparse +import logging +import os + +import torch.nn +import numpy as np + +import ChatTTS +import ChatTTS.model.gpt +import ChatTTS.model.dvae +from ChatTTS.train.dataset import XzListTar +from ChatTTS.train.model import TrainModule, train_autoencoder, train_gpt + +from tools.normalizer import load_normalizer + +logging.basicConfig(level=logging.ERROR) + + +def main(): + parser = argparse.ArgumentParser(description="ChatTTS demo Launch") + parser.add_argument( + "--data_path", + type=str, + default="dummy_data/xz_list_style/speaker_A.list", + help="the data_path to json/list file", + ) + parser.add_argument("--tar_path", type=str, help="the tarball path with wavs") + parser.add_argument( + "--tar_in_memory", action="store_true", 
help="load tarball in memory" + ) + parser.add_argument( + "--process_ahead", + action="store_true", + help="process all data ahead during dataset initialization", + ) + parser.add_argument( + "--train_module", + type=str, + default="gpt", + choices=[ + "gpt_all", + "gpt_speaker", + "gpt", + "speaker", + "dvae", + "dvae_encoder", + "dvae_decoder", + "decoder", + ], + ) + parser.add_argument("--train_text", action="store_true", help="train text loss") + parser.add_argument("--gpt_lora", action="store_true", help="train gpt with lora") + # parser.add_argument('--gpt_kbit', type=int, default=16, help='train gpt with kbit') + parser.add_argument("--dvae_path", type=str) + parser.add_argument("--decoder_path", type=str) + parser.add_argument("--gpt_path", type=str) + parser.add_argument("--speaker_embeds_path", type=str) + parser.add_argument("--save_folder", type=str, default="./") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--color", action="store_true", help="colorful output") + args = parser.parse_args() + data_path: str = args.data_path + tar_path: str | None = args.tar_path + tar_in_memory: bool = args.tar_in_memory + process_ahead: bool = args.process_ahead + train_module: TrainModule = args.train_module + train_text: bool = args.train_text + gpt_lora: bool = args.gpt_lora + # gpt_kbit: int = args.gpt_kbit + save_folder: str = args.save_folder + batch_size: int = args.batch_size + epochs: int = args.epochs + + decoder_path: str = args.decoder_path + dvae_path: str = args.dvae_path + gpt_path: str = args.gpt_path + speaker_embeds_path: str = args.speaker_embeds_path + + chat = ChatTTS.Chat() + chat.load(compile=False) + # load pretrained models + if decoder_path is not None: + chat.decoder.load_state_dict(torch.load(decoder_path, map_location=chat.device)) + if dvae_path is not None: + chat.dvae.load_state_dict(torch.load(dvae_path, map_location=chat.device)) + if gpt_path is not None: + chat.gpt.load_state_dict(torch.load(gpt_path, map_location=chat.device)) + speaker_embeds: dict[str, torch.Tensor] = {} + if speaker_embeds_path is not None: + np_speaker_embeds: dict[str, np.ndarray] = np.load(speaker_embeds_path) + speaker_embeds = { + speaker: torch.from_numpy(speaker_embed).to(chat.device) + for speaker, speaker_embed in np_speaker_embeds.items() + } + + if train_module in [TrainModule.GPT_SPEAKER, TrainModule.GPT]: + if gpt_lora: + import peft + + # match gpt_kbit: + # case 4: + # quantization_config = transformers.BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_quant_type="nf4", + # bnb_4bit_use_double_quant=True, + # bnb_4bit_compute_dtype=torch.bfloat16, + # ) + # case 8: + # quantization_config = transformers.BitsAndBytesConfig( + # load_in_8bit=True, + # ) + # chat.gpt.gpt = transformers.LlamaModel.from_pretrained() + # peft.prepare_model_for_gpt_kbit_training(chat.gpt.gpt) + lora_config = peft.LoraConfig(r=8, lora_alpha=16) + chat.gpt.gpt = peft.get_peft_model(chat.gpt.gpt, lora_config) + + match train_module: + case ( + TrainModule.GPT_ALL + | TrainModule.GPT_SPEAKER + | TrainModule.GPT + | TrainModule.SPEAKER + | TrainModule.DECODER + ): + train = train_gpt + kwargs = {"train_text": train_text, "speaker_embeds": speaker_embeds} + case TrainModule.DVAE | TrainModule.DVAE_ENCODER | TrainModule.DVAE_DECODER: + train = train_autoencoder + kwargs = {} + case _: + raise ValueError(f"invalid train_module: {train_module}") + + load_normalizer(chat) + + dataset = XzListTar( + 
root=data_path, + tokenizer=chat.tokenizer._tokenizer, + normalizer=chat.normalizer, + tar_path=tar_path, + tar_in_memory=tar_in_memory, + process_ahead=process_ahead, + # speakers=None, # set(['speaker_A', 'speaker_B']) + ) + speaker_embeds = train( + chat=chat, + dataset=dataset, + train_module=train_module, + batch_size=batch_size, + epochs=epochs, + **kwargs, + ) + + if not os.path.isdir(save_folder): + os.makedirs(save_folder) + gpt_save_path = os.path.join(save_folder, "gpt.pth") + speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz") + decoder_save_path = os.path.join(save_folder, "decoder.pth") + dvae_save_path = os.path.join(save_folder, "dvae.pth") + if train_module in [TrainModule.GPT_SPEAKER, TrainModule.GPT] and gpt_lora: + chat.gpt.gpt = chat.gpt.gpt.merge_and_unload() + if speaker_embeds is not None: + np_speaker_embeds = { + speaker: speaker_embed.detach().cpu().numpy() + for speaker, speaker_embed in speaker_embeds.items() + } + match train_module: + case TrainModule.GPT_ALL: + torch.save(chat.gpt.state_dict(), gpt_save_path) + torch.save(chat.decoder.state_dict(), decoder_save_path) + np.savez(speaker_embeds_save_path, **np_speaker_embeds) + case TrainModule.GPT_SPEAKER: + torch.save(chat.gpt.state_dict(), gpt_save_path) + np.savez(speaker_embeds_save_path, **np_speaker_embeds) + case TrainModule.GPT: + torch.save(chat.gpt.state_dict(), gpt_save_path) + case TrainModule.DECODER: + torch.save(chat.decoder.state_dict(), decoder_save_path) + case TrainModule.SPEAKER: + np.savez(speaker_embeds_save_path, **np_speaker_embeds) + case TrainModule.DVAE | TrainModule.DVAE_ENCODER | TrainModule.DVAE_DECODER: + torch.save(chat.dvae.state_dict(), dvae_save_path) + print("save models to:", save_folder) + + +if __name__ == "__main__": + main() diff --git a/examples/finetune/infer_autoencoder.py b/examples/finetune/infer_autoencoder.py new file mode 100644 index 000000000..d23f9f3e0 --- /dev/null +++ b/examples/finetune/infer_autoencoder.py @@ -0,0 +1,98 @@ +""" +CUDA_VISIBLE_DEVICES=0 python examples/finetune/infer_autoencoder.py --data_path Bekki.list --tar_path data/Xz.tar +--dvae_path saved_models/dvae.pth +""" # noqa: E501 + +import argparse +import os + +import torch.utils.data +import torch.nn +import torchaudio + +import ChatTTS +import ChatTTS.model.gpt +import ChatTTS.model.dvae +from ChatTTS.train.dataset import XzListTar, AudioCollator +from ChatTTS.train.model import ( + get_mel_specs, + get_mel_attention_mask, + get_dvae_mel_specs, +) + + +def main(): + parser = argparse.ArgumentParser(description="ChatTTS demo Launch") + parser.add_argument("--save_path", type=str, default="./") + parser.add_argument( + "--data_path", + type=str, + default="dummy_data/xz_list_style/speaker_A.list", + help="the data_path to json/list file", + ) + parser.add_argument("--tar_path", type=str, help="the tarball path with wavs") + parser.add_argument("--dvae_path", type=str) + args = parser.parse_args() + save_path: str = args.save_path + data_path: str = args.data_path + tar_path: str | None = args.tar_path + dvae_path: str = args.dvae_path + + chat = ChatTTS.Chat() + chat.load(compile=False) + if dvae_path is not None: + chat.dvae.load_state_dict(torch.load(dvae_path, map_location=chat.device)) + + dataset = XzListTar( + root=data_path, + tokenizer=chat.tokenizer._tokenizer, + normalizer=chat.normalizer, + tar_path=tar_path, + ) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + shuffle=True, + collate_fn=AudioCollator(), + # num_workers=4, + ) + + batch 
+    waveforms: torch.Tensor = batch["waveforms"]  # (batch_size, time)
+    waveform_attention_mask: torch.Tensor = batch[
+        "waveform_attention_mask"
+    ]  # (batch_size, time)
+
+    waveforms = waveforms.to(chat.device, non_blocking=True)
+    waveform_attention_mask = waveform_attention_mask.to(chat.device, non_blocking=True)
+
+    mel_specs = get_mel_specs(chat, waveforms)  # (batch_size, 100, mel_len)
+    mel_attention_mask = get_mel_attention_mask(
+        waveform_attention_mask, mel_len=mel_specs.size(2)
+    )  # (batch_size, mel_len)
+    mel_specs = mel_specs * mel_attention_mask.unsqueeze(1)  # clip
+
+    dvae_mel_specs = get_dvae_mel_specs(
+        chat, mel_specs, mel_attention_mask
+    )  # (batch_size, 100, mel_len)
+    dvae_mel_specs = dvae_mel_specs * mel_attention_mask.unsqueeze(1)  # clip
+
+    wav = chat.vocos.decode(dvae_mel_specs).cpu()
+    org_wav = chat.vocos.decode(mel_specs).cpu()
+
+    print("Original Waveform shape:", org_wav.shape)
+    print("Reconstructed Waveform shape:", wav.shape)
+    torchaudio.save(
+        os.path.join(save_path, "infer_autoencoder_org.wav"),
+        org_wav[0].view(1, -1),
+        sample_rate=24_000,
+    )
+    torchaudio.save(
+        os.path.join(save_path, "infer_autoencoder.wav"),
+        wav[0].view(1, -1),
+        sample_rate=24_000,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/finetune/infer_gpt.py b/examples/finetune/infer_gpt.py
new file mode 100644
index 000000000..b22453d9d
--- /dev/null
+++ b/examples/finetune/infer_gpt.py
@@ -0,0 +1,100 @@
+"""
+CUDA_VISIBLE_DEVICES=0 python examples/finetune/infer_gpt.py --text "你好,我是恬豆"
+--gpt_path ./saved_models/gpt.pth --decoder_path ./saved_models/decoder.pth --speaker_embeds_path ./saved_models/speaker_embeds.npz
+"""  # noqa: E501

+import argparse
+import os
+import random
+
+import torch.utils.data
+import torch.nn
+import torchaudio
+import numpy as np
+
+import ChatTTS
+import ChatTTS.model.gpt
+import ChatTTS.model.dvae
+
+from tools.normalizer import load_normalizer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
+    parser.add_argument("--text", type=str, required=True)
+    parser.add_argument("--speaker", type=str)
+    parser.add_argument("--save_path", type=str, default="./")
+
+    parser.add_argument("--dvae_path", type=str)
+    parser.add_argument("--decoder_path", type=str)
+    parser.add_argument("--gpt_path", type=str)
+    parser.add_argument("--speaker_embeds_path", type=str)
+    args = parser.parse_args()
+    text: str = args.text
+    speaker: str | None = args.speaker
+    save_path: str = args.save_path
+    dvae_path: str | None = args.dvae_path
+    decoder_path: str | None = args.decoder_path
+    gpt_path: str | None = args.gpt_path
+    speaker_embeds_path: str | None = args.speaker_embeds_path
+
+    chat = ChatTTS.Chat()
+    chat.load(compile=False)
+    # load pretrained models
+    if decoder_path is not None:
+        chat.decoder.load_state_dict(torch.load(decoder_path, map_location=chat.device))
+    if dvae_path is not None:
+        chat.dvae.load_state_dict(torch.load(dvae_path, map_location=chat.device))
+    if gpt_path is not None:
+        chat.gpt.load_state_dict(torch.load(gpt_path, map_location=chat.device))
+    speaker_embeds: dict[str, torch.Tensor] = {}
+    if speaker_embeds_path is not None:
+        np_speaker_embeds: dict[str, np.ndarray] = np.load(speaker_embeds_path)
+        speaker_embeds = {
+            speaker: torch.from_numpy(speaker_embed).to(chat.device)
+            for speaker, speaker_embed in np_speaker_embeds.items()
+        }
+
+    if speaker is None:
+        if len(speaker_embeds) == 0:
+            speaker_embed = chat.speaker._sample_random()
+        else:
+            speaker_embed = random.choice(list(speaker_embeds.values()))
+    else:
+        speaker_embed = speaker_embeds[speaker]
+
+    load_normalizer(chat)
+
+    decoder_wav = chat.infer(
+        [text],
+        stream=False,
+        use_decoder=True,  # decoder head (default)
+        params_infer_code=ChatTTS.Chat.InferCodeParams(
+            spk_emb=chat.speaker._encode(speaker_embed),
+        ),
+    )
+    print("decoder wav shape:", decoder_wav[0].shape)
+    torchaudio.save(
+        os.path.join(save_path, "infer_gpt_decoder.wav"),
+        torch.from_numpy(decoder_wav[0]).view(1, -1),
+        sample_rate=24_000,
+    )
+
+    dvae_wav = chat.infer(
+        [text],
+        stream=False,
+        use_decoder=False,  # DVAE head, so this output differs from the decoder one above
+        params_infer_code=ChatTTS.Chat.InferCodeParams(
+            spk_emb=chat.speaker._encode(speaker_embed),
+        ),
+    )
+    print("dvae wav shape:", dvae_wav[0].shape)
+    torchaudio.save(
+        os.path.join(save_path, "infer_gpt_dvae.wav"),
+        torch.from_numpy(dvae_wav[0]).view(1, -1),
+        sample_rate=24_000,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/finetune/validate.py b/examples/finetune/validate.py
new file mode 100644
index 000000000..89d8db698
--- /dev/null
+++ b/examples/finetune/validate.py
@@ -0,0 +1,104 @@
+"""
+CUDA_VISIBLE_DEVICES=0 python examples/finetune/validate.py --color --data_path Bekki.list --tar_path data/Xz.tar
+--gpt_path ./saved_models/gpt.pth --decoder_path ./saved_models/decoder.pth --speaker_embeds_path ./saved_models/speaker_embeds.npz
+--dvae_path ./saved_models/dvae.pth
+
+--tar_in_memory --process_ahead
+"""  # noqa: E501
+
+import argparse
+import logging
+
+import torch.utils.data
+import torch.nn
+import torch.nn.functional
+from transformers.trainer_pt_utils import LabelSmoother
+import numpy as np
+
+import ChatTTS
+import ChatTTS.model.gpt
+import ChatTTS.model.dvae
+from ChatTTS.train.dataset import XzListTar
+from ChatTTS.train.model import train_autoencoder, train_gpt
+
+from tools.normalizer import load_normalizer
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+logging.basicConfig(level=logging.ERROR)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
+    parser.add_argument(
+        "--data_path",
+        type=str,
+        default="dummy_data/xz_list_style/speaker_A.list",
+        help="the data_path to json/list file",
+    )
+    parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
+    parser.add_argument(
+        "--tar_in_memory", action="store_true", help="load tarball in memory"
+    )
+    parser.add_argument(
+        "--process_ahead",
+        action="store_true",
+        help="process all data ahead during dataset initialization",
+    )
+    # parser.add_argument('--gpt_kbit', type=int, default=16, help='train gpt with kbit')
+    parser.add_argument("--dvae_path", type=str)
+    parser.add_argument("--decoder_path", type=str)
+    parser.add_argument("--gpt_path", type=str)
+    parser.add_argument("--speaker_embeds_path", type=str)
+    parser.add_argument("--color", action="store_true", help="colorful output")
+    args = parser.parse_args()
+    data_path: str = args.data_path
+    tar_path: str | None = args.tar_path
+    tar_in_memory: bool = args.tar_in_memory
+    process_ahead: bool = args.process_ahead
+    # gpt_kbit: int = args.gpt_kbit
+
+    decoder_path: str | None = args.decoder_path
+    dvae_path: str | None = args.dvae_path
+    gpt_path: str | None = args.gpt_path
+    speaker_embeds_path: str | None = args.speaker_embeds_path
+
+    chat = ChatTTS.Chat()
+    chat.load(compile=False)
+    # load pretrained models
+    if decoder_path is not None:
+        chat.decoder.load_state_dict(torch.load(decoder_path, map_location=chat.device))
+    if dvae_path is not None:
+        chat.dvae.load_state_dict(torch.load(dvae_path, map_location=chat.device))
+    if gpt_path is not None:
+        chat.gpt.load_state_dict(torch.load(gpt_path, map_location=chat.device))
+    speaker_embeds: dict[str, torch.Tensor] = {}
+    if speaker_embeds_path is not None:
+        np_speaker_embeds: dict[str, np.ndarray] = np.load(speaker_embeds_path)
+        speaker_embeds = {
+            speaker: torch.from_numpy(speaker_embed).to(chat.device)
+            for speaker, speaker_embed in np_speaker_embeds.items()
+        }
+
+    load_normalizer(chat)
+
+    dataset = XzListTar(
+        root=data_path,
+        tokenizer=chat.tokenizer._tokenizer,
+        normalizer=chat.normalizer,
+        tar_path=tar_path,
+        tar_in_memory=tar_in_memory,
+        process_ahead=process_ahead,
+        # speakers=None,  # set(['speaker_A', 'speaker_B'])
+    )
+    train_autoencoder(chat=chat, dataset=dataset, validate=True)
+    train_gpt(
+        chat=chat,
+        dataset=dataset,
+        speaker_embeds=speaker_embeds,
+        train_text=True,
+        validate=True,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/normalizer/__init__.py b/tools/normalizer/__init__.py
index 9a8929bb0..5dc3df1e1 100644
--- a/tools/normalizer/__init__.py
+++ b/tools/normalizer/__init__.py
@@ -1,2 +1,30 @@
+import ChatTTS
+from tools.logger import get_logger
+
 from .en import normalizer_en_nemo_text
 from .zh import normalizer_zh_tn
+
+
+logger = get_logger("Normalizer")
+
+
+def load_normalizer(chat: ChatTTS.Chat):
+    # try to load normalizer
+    try:
+        chat.normalizer.register("en", normalizer_en_nemo_text())
+    except ValueError as e:
+        logger.error(e)
+    except BaseException:
+        logger.warning("Package nemo_text_processing not found!")
+        logger.warning(
+            "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
+        )
+    try:
+        chat.normalizer.register("zh", normalizer_zh_tn())
+    except ValueError as e:
+        logger.error(e)
+    except BaseException:
+        logger.warning("Package WeTextProcessing not found!")
+        logger.warning(
+            "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
+        )
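
For reference, a minimal usage sketch of load_normalizer as the scripts in this patch call it (assuming the optional nemo_text_processing / WeTextProcessing packages are installed; without them, registration falls back to the warnings logged above):

import ChatTTS

from tools.normalizer import load_normalizer

chat = ChatTTS.Chat()
chat.load(compile=False)
load_normalizer(chat)  # registers the "en"/"zh" normalizers when the optional packages are available

# the input text here is arbitrary; any short sentence works
wavs = chat.infer(["Hello, this is a quick test."])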