Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)

Commit: update readme
@@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer

torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663
@@ -28,6 +28,8 @@ import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union

import numpy as np
@@ -42,6 +44,7 @@ import torchaudio

from matcha.utils.audio import mel_spectrogram

torch.set_num_threads(1)

class TritonPythonModel:
"""Triton Python model for Spark TTS.
@@ -62,6 +65,8 @@ class TritonPythonModel:
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
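Each Triton model parameter arrives wrapped in a nested dict with a `string_value` field, which is why the comprehension above unwraps it before use. A minimal sketch of that shape and of how the strategy default kicks in; the example values are illustrative placeholders, not taken from the actual config:

```python
# Illustrative shape of self.model_config['parameters'] as Triton passes it
# (the paths/values here are hypothetical placeholders).
parameters = {
    "llm_tokenizer_dir": {"string_value": "/workspace/CosyVoice2-0.5B"},
    "dynamic_chunk_strategy": {"string_value": "time_based"},
}

# Unwrap the Triton "string_value" wrapper into a flat str -> str dict.
model_params = {k: v["string_value"] for k, v in parameters.items()}

# Fall back to "exponential" when the parameter is absent from config.pbtxt.
strategy = model_params.get("dynamic_chunk_strategy", "exponential")
assert strategy in ("exponential", "time_based")
```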
@@ -72,6 +77,10 @@ class TritonPythonModel:
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15

def forward_llm(self, input_ids):
"""
Prepares the response from the language model based on the provided
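These three constants set the streaming granularity: at 25 semantic tokens per second, a 15-token hop corresponds to 0.6 s of audio per base chunk, and the flow model additionally looks ahead 3 tokens before each decode. A quick sanity-check sketch of that arithmetic (the actual chunk schedule is computed later in execute()):

```python
# Back-of-the-envelope chunk timing implied by the constants above.
token_frame_rate = 25       # semantic tokens per second of audio
token_hop_len = 15          # tokens handed to the flow model per chunk
flow_pre_lookahead_len = 3  # extra future tokens the flow model peeks at

chunk_duration_s = token_hop_len / token_frame_rate          # 0.6 s of audio per base chunk
first_decode_need = token_hop_len + flow_pre_lookahead_len   # 18 tokens buffered before first decode

print(chunk_duration_s, first_decode_need)
```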
@@ -99,7 +108,7 @@ class TritonPythonModel:
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 1024
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
@@ -109,6 +118,7 @@ class TritonPythonModel:
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
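The sampling dict above is presumably forwarded to a TensorRT-LLM model through Triton's BLS API inside the Python backend. A hedged sketch of what that wrapping generally looks like; the target model name and output tensor names are assumptions, not taken from this diff, and the code only runs inside a Triton Python backend process:

```python
# Hedged sketch: forwarding input_dict to a TensorRT-LLM backend via BLS.
input_tensors = [pb_utils.Tensor(name, value) for name, value in input_dict.items()]

llm_request = pb_utils.InferenceRequest(
    model_name="tensorrt_llm",                               # assumed BLS target model
    requested_output_names=["output_ids", "sequence_length"],  # assumed output names
    inputs=input_tensors,
)

# In decoupled mode exec() yields responses as tokens stream out.
for llm_response in llm_request.exec(decoupled=True):
    if llm_response.has_error():
        raise pb_utils.TritonModelException(llm_response.error().message())
    output_ids = pb_utils.get_output_tensor_by_name(llm_response, "output_ids").as_numpy()
```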
@@ -139,7 +149,6 @@ class TritonPythonModel:

# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
print(f"actual_output_ids: {actual_output_ids}")

yield actual_output_ids
else:
@@ -290,6 +299,15 @@ class TritonPythonModel:
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat

def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True

def execute(self, requests):
"""Execute inference on the batched requests.
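`_llm_gen_thread` is a producer: it drains the LLM token iterator into a shared list while execute() consumes tokens as they accumulate and raises a flag when generation ends. A standalone sketch of the same pattern with a dummy generator; all names here are illustrative only:

```python
import threading
import time

def producer(token_iter, shared_tokens, done_flag):
    # Mirrors _llm_gen_thread: append tokens, then raise the done flag.
    for chunk in token_iter:
        shared_tokens.extend(chunk)
    done_flag[0] = True

def fake_llm():
    for i in range(5):
        time.sleep(0.01)
        yield [i, i + 100]

shared_tokens, done_flag = [], [False]
t = threading.Thread(target=producer, args=(fake_llm(), shared_tokens, done_flag))
t.start()

consumed = 0
while not (done_flag[0] and consumed == len(shared_tokens)):
    if len(shared_tokens) - consumed >= 2:   # enough pending tokens for a "chunk"
        consumed += 2                        # a real consumer would synthesize here
    else:
        time.sleep(0.005)                    # same polling idea as the execute() loop
t.join()
```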
@@ -322,9 +340,7 @@ class TritonPythonModel:

flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
token_hop_len = 25
flow_pre_lookahead_len = 3

reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
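The first chunk gets padded so that the prompt tokens plus the first hop land on a hop-length boundary; this is the `prompt_token_pad` computation that appears in the loop below. A small worked example of that arithmetic, with an assumed prompt length:

```python
import numpy as np

token_hop_len = 25
flow_prompt_speech_token_len = 63   # assumed prompt length, for illustration only

# Round the prompt length up to the next multiple of the hop; pad the difference.
prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len
                       - flow_prompt_speech_token_len)          # 75 - 63 = 12
first_hop = token_hop_len + prompt_token_pad                    # 37 tokens in the first chunk
later_hops = token_hop_len                                      # 25 tokens afterwards

print(prompt_token_pad, first_hop, later_hops)
```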
@@ -340,47 +356,75 @@ class TritonPythonModel:

# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)

prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
print(f"here2")

if self.decoupled:
response_sender = request.get_response_sender()

semantic_token_ids_arr = []
llm_is_done_flag = [False]

llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)

semantic_token_ids_arr, token_offset = [], 0
for generated_ids in generated_ids_iter:
llm_thread.start()

generated_ids = generated_ids.tolist()
print(f"generated_id: {generated_ids}")
semantic_token_ids_arr.extend(generated_ids)
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len

prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len)
this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
print(f"this_token_hop_len: {this_token_hop_len}")
if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
print(f"this_tts_speech_token: {this_tts_speech_token}")
while True:
pending_num = len(semantic_token_ids_arr) - token_offset

if llm_is_done_flag[0]:
break

if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
print(f"here3")

sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
print(f"here4")
# Prepare response to send

audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)

self.logger.log_info(f"[{request_id}]")
token_offset += this_token_hop_len
print(f"here")
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")

if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)

chunk_index += 1
else:
time.sleep(0.02)

this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)

llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
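The two scheduling strategies in the loop above grow the hop length differently: "exponential" doubles it each chunk (token_frame_rate * 2 ** chunk_index), while "time_based" sizes the next hop from how far ahead of real time synthesis is running. A self-contained sketch of both policies; the inputs in the usage lines are illustrative values, not measurements:

```python
def next_hop_len(strategy, chunk_index, token_hop_len=15, token_frame_rate=25,
                 pending_num=0, duration=0.0, cost_time=0.0):
    """Sketch of the hop-length update performed after each emitted chunk."""
    if strategy == "exponential":
        # Double the chunk size every iteration: 25, 50, 100, ... tokens.
        hop = token_frame_rate * (2 ** chunk_index)
    else:  # "time_based"
        hop = token_hop_len
        if chunk_index > 0 and cost_time > 0:
            avg_chunk_time = cost_time / (chunk_index + 1)
            # How many average chunk times of audio headroom have been banked.
            multiples = (duration - cost_time) / avg_chunk_time
            if multiples > 4:
                hop = (pending_num // token_hop_len + 1) * token_hop_len
            elif multiples > 2:
                hop = (pending_num // token_hop_len) * token_hop_len
        hop = max(token_hop_len, hop)
    return hop

print(next_hop_len("exponential", chunk_index=2))                 # 100 tokens
print(next_hop_len("time_based", chunk_index=3, pending_num=40,
                   duration=2.4, cost_time=0.5))                  # 45 tokens
```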
@@ -47,11 +47,11 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)

ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)

class CosyVoice2:

def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):

self.model_dir = model_dir
self.fp16 = fp16
@@ -61,7 +61,7 @@ class CosyVoice2:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
@@ -77,8 +77,9 @@ class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
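With the device now threaded through both constructors, a caller can pin the flow and HiFT modules to a specific GPU instead of relying on the old hard-coded CUDA-availability check. A minimal usage sketch; the import path and model directory below are assumptions (in this diff the class lives inside the Triton model file), not confirmed by the source:

```python
import torch

# Hypothetical import path for the patched wrapper class; adjust to the real layout.
from cosyvoice2_token2wav import CosyVoice2

device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")

# load_jit is kept off here, matching the initialization used by the Triton model below.
token2wav = CosyVoice2(
    "/models/CosyVoice2-0.5B",   # assumed model_dir
    load_jit=False,
    load_trt=True,
    fp16=True,
    device=device,
)
```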
@@ -179,11 +180,11 @@ class TritonPythonModel:
model_dir = model_params["model_dir"]

# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")

self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)

logger.info("Token2Wav initialized successfully")
@@ -224,7 +225,6 @@ class TritonPythonModel:
else:
stream = False
request_id = request.request_id()
print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
@@ -234,7 +234,6 @@ class TritonPythonModel:
stream=stream,
finalize=finalize)
if finalize:
print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
self.token2wav_model.model.hift_cache_dict.pop(request_id)

else:
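token2wav is called once per intermediate chunk with finalize=False and a final time with finalize=True, after which the per-request HiFT streaming cache keyed by request_id has to be dropped. A sketch of that lifecycle; the chunk loop itself is elided (see the execute() diff above) and the handle name is a placeholder:

```python
# Hedged sketch: per-request lifecycle of the HiFT streaming cache.
# `token2wav_model` stands in for the CosyVoice2 wrapper initialized above.

request_id = "req-0"  # Triton request id used as the cache key

# ... intermediate chunks: token2wav(..., stream=True, finalize=False) populate
#     token2wav_model.model.hift_cache_dict[request_id] with vocoder state ...
# ... last chunk: token2wav(..., finalize=True) flushes the remaining audio ...

# Once finalized, the per-request cache entry must be removed, otherwise state
# from this stream would leak into a later request that reuses the same id.
token2wav_model.model.hift_cache_dict.pop(request_id, None)
```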