Skip to content

Commit

Permalink
Added E2, F5 model types.
Browse files Browse the repository at this point in the history
  • Loading branch information
niknah committed Nov 7, 2024
1 parent 88d4ca8 commit 8004747
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
52 changes: 41 additions & 11 deletions F5TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from comfy.utils import ProgressBar
from cached_path import cached_path
sys.path.append(Install.f5TTSPath)
from model import DiT # noqa E402
from model import DiT,UNetT # noqa E402
from model.utils_infer import ( # noqa E402
load_model,
preprocess_ref_audio_text,
Expand All @@ -28,6 +28,7 @@

class F5TTSCreate:
voice_reg = re.compile(r"\{(\w+)\}")
model_types = ["F5", "E2"]
tooltip_seed = "Seed. -1 = random"

def is_voice_name(self, word):
Expand All @@ -54,7 +55,33 @@ def load_voice(ref_audio, ref_text):
)
return main_voice

def load_model(self):
def load_model(self, model):
models = {
"F5": self.load_f5_model,
"E2": self.load_e2_model,
}
return models[model]()

def get_vocab_file(self):
return os.path.join(
Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
)

def load_e2_model(self):
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
vocab_file = self.get_vocab_file()
ema_model = load_model(
model_cls, model_cfg,
ckpt_file, vocab_file
)
return ema_model

def load_f5_model(self):
model_cls = DiT
model_cfg = dict(
dim=1024, depth=22, heads=16,
Expand All @@ -64,10 +91,11 @@ def load_model(self):
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501
vocab_file = os.path.join(
Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
vocab_file = self.get_vocab_file()
ema_model = load_model(
model_cls, model_cfg,
ckpt_file, vocab_file
)
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
return ema_model

def generate_audio(self, voices, model_obj, chunks, seed):
Expand Down Expand Up @@ -117,8 +145,8 @@ def generate_audio(self, voices, model_obj, chunks, seed):
os.unlink(wave_file.name)
return audio

def create(self, voices, chunks, seed=-1):
model_obj = self.load_model()
def create(self, voices, chunks, seed=-1, model="F5"):
model_obj = self.load_model(model)
return self.generate_audio(voices, model_obj, chunks, seed)


Expand All @@ -141,6 +169,7 @@ def INPUT_TYPES(s):
"default": 1, "min": -1,
"tooltip": F5TTSCreate.tooltip_seed,
}),
"model": (F5TTSCreate.model_types,),
},
}

Expand Down Expand Up @@ -174,7 +203,7 @@ def remove_wave_file(self):
print("F5TTS: Cannot remove? "+self.wave_file.name)
print(e)

def create(self, sample_audio, sample_text, speech, seed=-1):
def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"):
try:
main_voice = self.load_voice_from_input(sample_audio, sample_text)

Expand All @@ -184,7 +213,7 @@ def create(self, sample_audio, sample_text, speech, seed=-1):
chunks = f5ttsCreate.split_text(speech)
voices['main'] = main_voice

audio = f5ttsCreate.create(voices, chunks, seed)
audio = f5ttsCreate.create(voices, chunks, seed, model)
finally:
self.remove_wave_file()
return (audio, )
Expand Down Expand Up @@ -233,6 +262,7 @@ def INPUT_TYPES(s):
"default": 1, "min": -1,
"tooltip": F5TTSCreate.tooltip_seed,
}),
"model": (F5TTSCreate.model_types,),
}
}

Expand Down Expand Up @@ -289,7 +319,7 @@ def load_voices_from_files(self, sample, voice_names):
voices[voice_name] = self.load_voice_from_file(sample_file)
return voices

def create(self, sample, speech, seed=-1):
def create(self, sample, speech, seed=-1, model="F5"):
# Install.check_install()
main_voice = self.load_voice_from_file(sample)

Expand All @@ -309,7 +339,7 @@ def create(self, sample, speech, seed=-1):
voices = self.load_voices_from_files(sample, voice_names)
voices['main'] = main_voice

audio = f5ttsCreate.create(voices, chunks, seed)
audio = f5ttsCreate.create(voices, chunks, seed, model)
return (audio, )

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-f5-tts"
description = "Text to speech with F5-TTS"
version = "1.0.4"
version = "1.0.5"
license = {text = "MIT License"}

[project.urls]
Expand Down

0 comments on commit 8004747

Please sign in to comment.