mirror of https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix lint
@@ -413,7 +413,7 @@ def run_sync_streaming_inference(
     for i in range(1, len(audios)):
         # Cross-fade section
         cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-            audios[i - 1][-cross_fade_samples:] * fade_out)
+                               audios[i - 1][-cross_fade_samples:] * fade_out)
         # Middle section of the current chunk
         middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
         # Concatenate
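The hunk above only realigns the chunk-stitching loop of the streaming client. For context, a self-contained NumPy sketch of the same cross-fade concatenation; the linear fade windows are an assumption, since `fade_in`/`fade_out` are built outside this hunk:

import numpy as np

def crossfade_concat(audios, cross_fade_samples):
    """Stitch audio chunks, linearly cross-fading each overlap region."""
    if len(audios) == 1:
        return audios[0]
    fade_in = np.linspace(0.0, 1.0, cross_fade_samples)  # assumed linear ramps
    fade_out = 1.0 - fade_in
    out = [audios[0][:-cross_fade_samples]]
    for i in range(1, len(audios)):
        # Cross-fade section: head of this chunk mixed with tail of the previous
        out.append(audios[i][:cross_fade_samples] * fade_in +
                   audios[i - 1][-cross_fade_samples:] * fade_out)
        # Middle section of the current chunk (keep the tail on the last chunk)
        tail = None if i == len(audios) - 1 else -cross_fade_samples
        out.append(audios[i][cross_fade_samples:tail])
    return np.concatenate(out)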
@@ -41,11 +41,11 @@ from transformers import AutoTokenizer
 import torchaudio



 from matcha.utils.audio import mel_spectrogram

 torch.set_num_threads(1)


 class TritonPythonModel:
     """Triton Python model for Spark TTS.
@@ -65,7 +65,7 @@ class TritonPythonModel:
         parameters = self.model_config['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
-        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

         # Initialize tokenizer
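For readers unfamiliar with the Triton Python backend: `model_config` arrives as a JSON string, and every custom parameter is wrapped in a `string_value` field regardless of its logical type, which is why the comprehension above unwraps them. A minimal self-contained sketch with illustrative values:

import json

# What Triton hands the model in args['model_config'] (values illustrative).
model_config = json.loads('''{
    "parameters": {
        "model_dir": {"string_value": "/models/cosyvoice2"},
        "dynamic_chunk_strategy": {"string_value": "exponential"}
    }
}''')
model_params = {k: v["string_value"] for k, v in model_config['parameters'].items()}
assert model_params["dynamic_chunk_strategy"] == "exponential"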
@@ -193,7 +193,6 @@ class TritonPythonModel:

         return prompt_speech_tokens

-
     def forward_speaker_embedding(self, wav):
         """Forward pass through the speaker embedding component.

@@ -219,7 +218,6 @@ class TritonPythonModel:

         return prompt_spk_embedding

-
     def forward_token2wav(
             self,
             prompt_speech_tokens: torch.Tensor,
@@ -254,7 +252,6 @@ class TritonPythonModel:
         inputs_tensor.append(token_offset_tensor)
         inputs_tensor.append(finalize_tensor)

-
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
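`pb_utils.InferenceRequest` is Triton's Business Logic Scripting (BLS) mechanism for calling another model in-process. A hedged sketch of the request/response round trip implied above; the `waveform` output name is taken from later hunks, and the error handling is an assumption:

# Runs inside a Triton Python backend; pb_utils is provided by the runtime.
import triton_python_backend_utils as pb_utils

inference_request = pb_utils.InferenceRequest(
    model_name='token2wav',
    requested_output_names=['waveform'],
    inputs=inputs_tensor)  # list of pb_utils.Tensor assembled above
inference_response = inference_request.exec()  # synchronous BLS call
if inference_response.has_error():
    raise pb_utils.TritonModelException(inference_response.error().message())
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')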
@@ -281,7 +278,6 @@ class TritonPythonModel:
         input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
         return input_ids

-
     def _extract_speech_feat(self, speech):
         speech_feat = mel_spectrogram(
             speech,
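`mel_spectrogram` here is Matcha-TTS's HiFi-GAN-style feature extractor. A hedged sketch of a complete call; the parameter values are assumptions chosen to match a 24 kHz CosyVoice2 setup, since the hunk truncates the argument list:

import torch
from matcha.utils.audio import mel_spectrogram

speech = torch.randn(1, 24000)  # one second of 24 kHz audio, batch-first
speech_feat = mel_spectrogram(
    speech,
    n_fft=1920, num_mels=80, sampling_rate=24000,  # assumed settings
    hop_size=480, win_size=1920, fmin=0, fmax=8000,
    center=False)  # -> [1, 80, n_frames]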
@@ -299,7 +295,6 @@ class TritonPythonModel:
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat

-
     def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
         for generated_ids in generated_ids_iter:
             generated_ids = generated_ids.tolist()
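`_llm_gen_thread` is the producer half of a producer/consumer pattern: a thread drains the LLM's streaming iterator into a shared token list while the main thread synthesizes audio from it. A minimal sketch under the assumptions that the iterator yields token-id arrays and the done flag is a mutable one-element list (the diff does not show its actual type):

import time

def llm_gen_thread(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
    # Producer: append every generated token id to the shared list.
    for generated_ids in generated_ids_iter:
        semantic_token_ids_arr.extend(generated_ids.tolist())
    llm_is_done_flag[0] = True

def consume_chunks(semantic_token_ids_arr, llm_is_done_flag, hop):
    # Consumer: emit a chunk whenever enough new tokens have arrived.
    token_offset = 0
    while True:
        if len(semantic_token_ids_arr) - token_offset >= hop:
            yield semantic_token_ids_arr[token_offset:token_offset + hop]
            token_offset += hop
        elif llm_is_done_flag[0]:
            yield semantic_token_ids_arr[token_offset:]  # final partial chunk
            return
        else:
            time.sleep(0.02)  # matches the polling interval in the diff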
@@ -338,9 +333,8 @@ class TritonPythonModel:
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

-
             flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]

             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')

@@ -385,7 +379,9 @@ class TritonPythonModel:
                 this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                 this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)

-                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
+                sub_tts_speech = self.forward_token2wav(
+                    prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
+                    this_tts_speech_token, request_id, token_offset, False)

                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
@@ -413,7 +409,6 @@ class TritonPythonModel:
                 else:
                     this_token_hop_len = self.token_hop_len
                 this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
-
                 chunk_index += 1
             else:
                 time.sleep(0.02)
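The hop-length update above is the tail of the dynamic chunk scheduler configured by `dynamic_chunk_strategy`. A heavily hedged sketch of how the two named strategies could drive it; the exponential growth factor and the time-based formula are assumptions, the diff only shows the final clamp:

def next_token_hop_len(strategy, base_hop_len, chunk_index,
                       elapsed_s=0.0, generated_s=0.0):
    if strategy == "exponential":
        # Assumed: each chunk doubles the previous hop, so early chunks are
        # small (low first-packet latency) and later chunks are large.
        this_token_hop_len = base_hop_len * (2 ** chunk_index)
    else:  # "time_based"
        # Assumed: size the chunk from how far generation leads playback.
        lead = max(0.0, generated_s - elapsed_s)
        this_token_hop_len = int(base_hop_len * (1 + lead))
    # From the diff: never go below the configured base hop length.
    return max(base_hop_len, this_token_hop_len)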
@@ -423,7 +418,7 @@ class TritonPythonModel:
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             response_sender.send(inference_response)

             llm_thread.join()
             response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
             self.logger.log_info("send tritonserver_response_complete_final to end")
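This model runs in Triton's decoupled mode, where results flow through the response sender rather than a return value, and `from_dlpack` hands each CUDA tensor to Triton without a host copy. A compact sketch of the streaming lifecycle the hunk above finishes (names match the diff; the chunk loop is condensed):

from torch.utils.dlpack import to_dlpack
# import triton_python_backend_utils as pb_utils  # provided by the runtime

def stream_audio(response_sender, chunks, llm_thread):
    # Decoupled backend: each audio chunk becomes its own InferenceResponse.
    for sub_tts_speech in chunks:
        # DLPack handoff: zero-copy transfer of the GPU tensor to Triton.
        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
        response_sender.send(pb_utils.InferenceResponse(output_tensors=[audio_tensor]))
    llm_thread.join()
    # Explicit end-of-stream marker required in decoupled mode.
    response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)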
@@ -57,13 +57,13 @@ class TritonPythonModel:
         self.device = torch.device("cuda")

         model_dir = model_params["model_dir"]
-        gpu="l20"
+        gpu = "l20"
         enable_trt = True
         if enable_trt:
             self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
                               f'{model_dir}/campplus.onnx',
                               1,
                               False)
         else:
             campplus_model = f'{model_dir}/campplus.onnx'
             option = onnxruntime.SessionOptions()
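When `enable_trt` is off, the campplus speaker model falls back to ONNX Runtime. A hedged sketch of that fallback path; the diff only shows `SessionOptions()`, so the optimization level, thread count, and provider choice are assumptions:

import onnxruntime

option = onnxruntime.SessionOptions()
# Assumed tuning: full graph optimization, single intra-op thread so the
# Triton instance count controls parallelism.
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
session = onnxruntime.InferenceSession(
    "campplus.onnx", sess_options=option,
    providers=["CPUExecutionProvider"])  # assumed provider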
@@ -121,7 +121,7 @@ class TritonPythonModel:
         assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
         torch.cuda.current_stream().synchronize()
         self.spk_model.release_estimator(spk_model, stream)

         return embedding.half()

     def execute(self, requests):
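`execute_async_v3` enqueues the TensorRT engine on an explicit CUDA stream and returns immediately, so the synchronize call is what actually waits for the embedding. A sketch of the pattern in isolation; binding setup is elided, and `context` stands for the `IExecutionContext` the diff calls `spk_model`:

import torch

def run_trt(context):
    # TensorRT execute_async_v3 takes a raw CUDA stream handle and returns
    # True when the inference was successfully enqueued.
    stream = torch.cuda.current_stream().cuda_stream
    assert context.execute_async_v3(stream) is True
    # Block until the enqueued GPU work (and the output tensors) are ready.
    torch.cuda.current_stream().synchronize()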
@@ -142,7 +142,6 @@ class TritonPythonModel:
             wav_array = torch.from_numpy(wav_array).to(self.device)

             embedding = self._extract_spk_embedding(wav_array)

-
             prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
                 "prompt_spk_embedding", to_dlpack(embedding))
@@ -49,6 +49,7 @@ logger = logging.getLogger(__name__)
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)

+
 class CosyVoice2:

     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
@@ -123,7 +124,6 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

-
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
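`torch.cuda.amp.autocast(self.fp16)` uses its first positional argument as the `enabled` flag, so half precision applies only when the model was loaded with `fp16=True`. A minimal sketch of the gating pattern; the flow module and its return shape are placeholders mirroring the call above:

import torch

def run_flow(flow, token, fp16: bool):
    # autocast is a no-op when enabled=False, so one code path serves both
    # full- and half-precision deployments.
    with torch.cuda.amp.autocast(enabled=fp16):
        tts_mel, _ = flow.inference(token=token)
    return tts_mel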