add hifigan train

This commit is contained in:
lyuxiang.lx
2024-10-16 11:37:32 +08:00
parent cb200b21c5
commit 789ee9e5e7
13 changed files with 314 additions and 477 deletions

View File

@@ -26,11 +26,13 @@ class CosyVoiceModel:
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module):
hift: torch.nn.Module,
fp16: bool):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
@@ -56,13 +58,17 @@ class CosyVoiceModel:
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
self.llm.to(self.device).eval()
self.llm.half()
if self.fp16 is True:
self.llm.half()
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
self.flow.to(self.device).eval()
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
self.hift.load_state_dict(hift_state_dict, strict=False)
self.hift.to(self.device).eval()
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
@@ -80,6 +86,8 @@ class CosyVoiceModel:
self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
if self.fp16 is True:
llm_embedding = llm_embedding.half()
with self.llm_context:
for i in self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
@@ -87,7 +95,7 @@ class CosyVoiceModel:
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device).half()):
embedding=llm_embedding.to(self.device)):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
@@ -123,7 +131,7 @@ class CosyVoiceModel:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech