Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)

Commit: clean code
@@ -43,7 +43,7 @@ import torchaudio
from matcha.utils.audio import mel_spectrogram
from datetime import datetime


ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
@@ -85,9 +85,7 @@ class TritonPythonModel:
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
        # self.dynamic_chunk_strategy = "equal"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
@@ -103,12 +101,8 @@ class TritonPythonModel:
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]
        self.http_client = httpx.AsyncClient()
        self.api_base = "http://localhost:8000/v1/chat/completions"

    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
        """Converts a tensor or list of speech token IDs to a string representation."""
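For orientation, a minimal standalone sketch of what such a conversion could look like, assuming the `<|s_N|>` marker format that the streaming parser below matches with `<\|s_(\d+)\|>`; the repository's actual implementation may differ.

# Hedged sketch, not the repository's exact implementation.
from typing import List, Union

import torch

def convert_speech_tokens_to_str(speech_tokens: Union[torch.Tensor, List[int]]) -> str:
    """Render speech token ids as '<|s_N|>' markers, e.g. [1, 2] -> '<|s_1|><|s_2|>'."""
    if isinstance(speech_tokens, torch.Tensor):
        speech_tokens = speech_tokens.flatten().tolist()
    return "".join(f"<|s_{int(token_id)}|>" for token_id in speech_tokens)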
@@ -147,12 +141,8 @@ class TritonPythonModel:
            "stream": True,
        }

        api_base = "http://localhost:8000/v1/chat/completions"

        buffer = ""
        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
        print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
        print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
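As a self-contained illustration of the streaming pattern above (an httpx AsyncClient POST with "stream": True, then line-by-line parsing of the SSE-style response), here is a hedged sketch; the model name and message formatting are placeholders, not the repository's exact values.

# Hedged sketch of consuming an OpenAI-compatible streaming completion with httpx.
import json

import httpx

async def stream_chat_completion(prompt: str, api_base: str = "http://localhost:8000/v1/chat/completions"):
    payload = {
        "model": "placeholder-model",   # assumption: whatever model the server exposes
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,                 # ask the server for incremental chunks
    }
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", api_base, json=payload, timeout=None) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data.strip() == "[DONE]":  # OpenAI-compatible end-of-stream marker
                    break
                content = json.loads(data).get("choices", [{}])[0].get("delta", {}).get("content")
                if content:
                    yield content

# usage: async for piece in stream_chat_completion("text to speak"): ...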
@@ -164,7 +154,6 @@ class TritonPythonModel:
                    content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                    if content:
                        buffer += content
                        print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                        while True:
                            match = re.search(r"<\|s_(\d+)\|>", buffer)
                            if not match:
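The loop above extracts speech-token markers incrementally as text streams in: complete `<|s_N|>` markers are consumed, while a trailing partial marker stays in the buffer until more text arrives. A hedged standalone sketch of that parsing step:

# Sketch of incremental buffer parsing for '<|s_N|>' markers.
import re
from typing import List, Tuple

SPEECH_TOKEN_RE = re.compile(r"<\|s_(\d+)\|>")

def drain_speech_tokens(buffer: str) -> Tuple[List[int], str]:
    """Return (token_ids, remaining_buffer), consuming through each complete marker."""
    token_ids = []
    while True:
        match = SPEECH_TOKEN_RE.search(buffer)
        if not match:
            break
        token_ids.append(int(match.group(1)))
        buffer = buffer[match.end():]
    return token_ids, buffer

# e.g. drain_speech_tokens("<|s_12|><|s_34|><|s_5") -> ([12, 34], "<|s_5")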
@@ -307,40 +296,24 @@ class TritonPythonModel:
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

        # Process reference audio through audio tokenizer
        if wav is not None:
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)
            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)
            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

            # reference_text = self.default_spk_info["prompt_text"]
            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            # prompt_speech_feat = None
            # prompt_spk_embedding = None

        else:
            # using pre-cached reference text
            assert False, "using pre-cached reference text is not supported"
            reference_text = self.default_spk_info["prompt_text"]
            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            prompt_speech_feat = None
            prompt_spk_embedding = None
        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
        reference_text = reference_text[0][0].decode('utf-8')

        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
        target_text = target_text[0][0].decode('utf-8')
        print(f"target_text: {target_text}, time: {datetime.now()}")

        if self.decoupled:
            response_sender = request.get_response_sender()
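The prompt conditioning above follows a fixed ratio of two 24 kHz mel frames per speech token: the 16 kHz reference is resampled, mel features are extracted, and both the features and the prompt tokens are truncated so that frames = 2 × tokens. A hedged sketch of that alignment, with the mel extractor passed in as a callable since the exact feature settings live elsewhere in the file:

# Hedged sketch of the prompt feature/token alignment used above.
import torch
import torchaudio

def align_prompt(prompt_wav_16k: torch.Tensor,
                 prompt_speech_tokens: torch.Tensor,
                 extract_mel) -> tuple[torch.Tensor, torch.Tensor]:
    """prompt_wav_16k: (1, T) waveform; prompt_speech_tokens: (1, N) token ids.
    extract_mel: callable returning (1, frames, mel_bins) features from 24 kHz audio."""
    wav_24k = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(prompt_wav_16k)
    speech_feat = extract_mel(wav_24k)
    # Keep an integer number of tokens, with exactly 2 mel frames per token.
    token_len = min(speech_feat.shape[1] // 2, prompt_speech_tokens.shape[-1])
    speech_feat = speech_feat[:, :2 * token_len].contiguous()
    prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
    return speech_feat, prompt_speech_tokens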
@@ -349,7 +322,6 @@ class TritonPythonModel:
            token_offset, chunk_index = 0, 0
            start_time = time.time()
            this_token_hop_len = self.token_hop_len
            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
            async for generated_ids in self.forward_llm_async(
                target_text=target_text,
                reference_text=reference_text,
@@ -358,24 +330,20 @@ class TritonPythonModel:
                if not generated_ids:
                    break
                semantic_token_ids_arr.append(generated_ids)
                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
                        sub_tts_speech = await self.forward_token2wav(
                            chunk_index,
                            this_tts_speech_token, request_id, wav, wav_len, False
                        )
                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")

                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
@@ -389,7 +357,6 @@ class TritonPythonModel:
                            avg_chunk_processing_time = cost_time / (chunk_index + 1)
                            if avg_chunk_processing_time > 0:
                                multiples = (duration - cost_time) / avg_chunk_processing_time
                                self.logger.log_info(f"multiples: {multiples}")
                                next_pending_num = len(semantic_token_ids_arr) - token_offset
                                if multiples > 4:
                                    this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
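A hedged, simplified sketch of the two chunk-size schedules used above; `duration` is assumed to be the seconds of audio synthesized so far and `cost_time` the wall-clock time spent, matching how the surrounding code compares generation speed with playback. The "multiples > 4" threshold mirrors the code; other details are simplified.

# Hedged sketch of the dynamic chunk-size schedules.
def next_token_hop_len(strategy: str, chunk_index: int, token_frame_rate: int,
                       base_hop_len: int, pending_num: int,
                       duration: float, cost_time: float) -> int:
    if strategy == "exponential":
        # Start small for low first-chunk latency, then double the hop each chunk.
        return token_frame_rate * (2 ** chunk_index)
    if strategy == "time_based":
        # If synthesis is running well ahead of playback, flush more pending tokens at once.
        avg_chunk_time = cost_time / (chunk_index + 1)
        if avg_chunk_time > 0:
            multiples = (duration - cost_time) / avg_chunk_time
            if multiples > 4:
                return (pending_num // base_hop_len + 1) * base_hop_len
        return base_hop_len
    return base_hop_len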
@@ -409,9 +376,8 @@ class TritonPythonModel:
            response_sender.send(inference_response)

            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.logger.log_info("send tritonserver_response_complete_final to end")
        else:
            raise NotImplementedError("Decoupled mode is not supported")
            raise NotImplementedError("Offline TTS mode is not supported")

    async def execute(self, requests):
        """Execute inference on the batched requests.
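The decoupled branch above follows Triton's usual pattern for streaming Python models: send one InferenceResponse per audio chunk through the request's response sender, then flag the stream complete. A hedged standalone sketch (pb_utils is only available inside Triton's Python backend at runtime; the output name follows the code above):

# Hedged sketch of the decoupled response pattern; runs only inside Triton's
# Python backend, which provides triton_python_backend_utils at runtime.
import numpy as np
import triton_python_backend_utils as pb_utils

def send_audio_chunks(request, waveform_chunks):
    """Send each waveform chunk as its own response, then signal completion."""
    response_sender = request.get_response_sender()
    for chunk in waveform_chunks:  # each chunk: 1-D float32 numpy array
        audio_tensor = pb_utils.Tensor("waveform", chunk.astype(np.float32))
        response_sender.send(pb_utils.InferenceResponse(output_tensors=[audio_tensor]))
    response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)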
@@ -106,13 +106,10 @@ class TritonPythonModel:
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
            # shift the speech tokens according to the original vocab size
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens.squeeze().tolist()

            # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.

            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()

            request_id = request.request_id()
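The subtraction above undoes the offset applied on the LLM side, where speech tokens occupy vocabulary ids starting at ORIGINAL_VOCAB_SIZE = 151663 (the same constant is added to prompt speech tokens earlier in the diff). A hedged sketch of the two directions of that mapping:

# Hedged sketch of the vocabulary-offset handling between LLM ids and codec ids.
import torch

ORIGINAL_VOCAB_SIZE = 151663  # size of the text vocabulary, as defined above

def llm_ids_to_speech_ids(llm_token_ids: torch.Tensor) -> list[int]:
    """Map LLM-vocabulary ids back to 0-based speech-codec token ids."""
    return (llm_token_ids - ORIGINAL_VOCAB_SIZE).squeeze().tolist()

def speech_ids_to_llm_ids(speech_token_ids: torch.Tensor) -> torch.Tensor:
    """Map 0-based speech-codec ids into the LLM vocabulary range."""
    return speech_token_ids + ORIGINAL_VOCAB_SIZE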
@@ -124,23 +121,14 @@ class TritonPythonModel:
                request, "reference_wav_len").as_numpy().item()

            wav_array = torch.from_numpy(wav_array)
            # Prepare inputs
            wav = wav_array[:, :wav_len].squeeze(0)

            spk_id = get_spk_id_from_prompt_audio(wav)
            # wav = wav.to(self.device)

            # update cache before forward
            # self.token2wav_model.streaming_flow_cache[request_id]
            # self.token2wav_model.hift_cache_dict[request_id]

            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)

            # get the cache after forward
            outputs = []

            generated_wave = audio_hat.squeeze(0).cpu().numpy()

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            outputs.append(wav_tensor)
            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
@@ -320,7 +320,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
    def forward(
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # assert all item in prompt_audios_sample_rate is 16000
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
@@ -335,7 +334,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
    def prepare_prompt_audio(
        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # assert all item in prompt_audios_sample_rate is 16000
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
@@ -385,7 +383,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
            print(f"speaker_id {speaker_id} added to cache")

        if request_id not in self.streaming_flow_cache:
            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
@@ -394,12 +391,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
                source = torch.zeros(1, 1, 0, device='cuda'),
                speech = torch.zeros(1, 0, device='cuda'),
            )
        # else:
        #     for k, v in self.streaming_flow_cache[request_id].items():
        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
        #     for k, v in self.hift_cache_dict[request_id].items():
        #         print(f"k: {k}, v: {v.shape}, dtype: {v.dtype}")
        # breakpoint()

        current_request_cache = self.streaming_flow_cache[request_id]
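The caching above separates prompt-dependent state, computed once per speaker, from per-request streaming state, which is cloned so that concurrent requests never mutate a shared copy. A hedged sketch of that scheme with illustrative names, not the repository's exact structures:

# Hedged sketch of speaker-level vs. request-level streaming caches.
class StreamingCacheStore:
    def __init__(self):
        self.speaker_cache = {}         # speaker_id -> prompt-conditioned tensors
        self.streaming_flow_cache = {}  # request_id -> mutable per-stream tensors

    def get_request_cache(self, request_id: str, speaker_id: str, build_cache):
        if speaker_id not in self.speaker_cache:
            self.speaker_cache[speaker_id] = build_cache()  # expensive, done once per speaker
        if request_id not in self.streaming_flow_cache:
            # Clone so streaming updates for this request never touch the shared copy.
            self.streaming_flow_cache[request_id] = {
                k: v.clone() for k, v in self.speaker_cache[speaker_id].items()
            }
        return self.streaming_flow_cache[request_id]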
@@ -477,7 +468,6 @@ def get_args():
if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
    # mkdir output_dir if not exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    dataset_name = "yuekai/seed_tts_cosy2"