mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
update vc/tts code
This commit is contained in:
@@ -35,7 +35,7 @@ class CosyVoiceModel:
|
||||
self.token_max_hop_len = 200
|
||||
self.token_overlap_len = 20
|
||||
# mel fade in out
|
||||
self.mel_overlap_len = 34
|
||||
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
||||
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
||||
# hift cache
|
||||
self.mel_cache_len = 20
|
||||
@@ -54,9 +54,10 @@ class CosyVoiceModel:
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
||||
self.llm.to(self.device).eval()
|
||||
self.llm.half()
|
||||
if self.llm is not None:
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
||||
self.llm.to(self.device).eval()
|
||||
self.llm.half()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
|
||||
self.flow.to(self.device).eval()
|
||||
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
||||
@@ -131,11 +132,11 @@ class CosyVoiceModel:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
@@ -148,7 +149,8 @@ class CosyVoiceModel:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
@@ -164,7 +166,7 @@ class CosyVoiceModel:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
@@ -175,7 +177,58 @@ class CosyVoiceModel:
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
|
||||
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
|
||||
self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=False)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||
# increase token_hop_len for better speech quality
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid], dim=1).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
|
||||
@@ -125,7 +125,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
|
||||
|
||||
@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
|
||||
olens = ylens
|
||||
return out * mask, olens
|
||||
|
||||
def inference(self, x1, x2, mel_len1, mel_len2):
|
||||
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
||||
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
||||
# x in (B, T, D)
|
||||
if x2.shape[1] > 40:
|
||||
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
|
||||
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
|
||||
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
||||
mode='linear')
|
||||
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
||||
else:
|
||||
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from whisper.tokenizer import Tokenizer
|
||||
|
||||
import tiktoken
|
||||
|
||||
@@ -145,6 +145,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel.to(device)
|
||||
|
||||
|
||||
def set_all_random_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user