add prompt audio cache
@@ -77,16 +77,19 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 **Streaming TTS (First Chunk Latency)**
 | Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|
-| Streaming, Decoupled=True | 1 | 220.43 | 218.07 | 0.1237 |
-| Streaming, Decoupled=True | 2 | 476.97 | 369.25 | 0.1022 |
-| Streaming, Decoupled=True | 4 | 1107.34 | 1243.75 | 0.0922 |
+| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
+| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
+| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75 | 0.0922 |
+| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
+| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
+| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68 | 0.0733 |

 **Offline TTS (Full Sentence Latency)**
 | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|---|
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
-| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |

 ### OpenAI-Compatible Server

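Taken directly from the table above, enabling the spk2info cache cuts streaming first-chunk latency by roughly 14% at concurrency 1 (220.43 → 189.88 ms), 32% at concurrency 2 (476.97 → 323.04 ms), and 12% at concurrency 4 (1107.34 → 977.68 ms), with RTF improving correspondingly; the offline rows carry the same numbers as before and are only relabeled.
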
@@ -257,7 +257,13 @@ def get_args():
         default=0.1,
         help="Chunk overlap duration for streaming reconstruction (in seconds)."
     )
-    # --- End Added arguments ---
+
+    parser.add_argument(
+        "--use-spk2info-cache",
+        type=bool,
+        default=False,
+        help="Use spk2info cache for reference audio.",
+    )

     return parser.parse_args()

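Note that argparse's `type=bool` applies `bool()` to the raw command-line string, so any non-empty value, including the literal `False` that run.sh passes, parses as `True`. If both values need to be selectable from the shell, a small converter is the usual pattern; a minimal sketch (the helper name is illustrative, not part of this commit):

```python
def str2bool(value: str) -> bool:
    """Parse common textual booleans; anything else is rejected by argparse."""
    normalized = str(value).lower()
    if normalized in ("true", "1", "yes"):
        return True
    if normalized in ("false", "0", "no"):
        return False
    raise ValueError(f"expected a boolean, got {value!r}")

# parser.add_argument("--use-spk2info-cache", type=str2bool, default=False, ...)
```
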
@@ -283,7 +289,8 @@ def prepare_request_input_output(
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None, # Optional padding for offline mode
+    use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"

@@ -330,7 +337,8 @@ prepare_request_input_output(
     inputs[3].set_data_from_numpy(input_data_numpy)

     outputs = [protocol_client.InferRequestedOutput("waveform")]
+    if use_spk2info_cache:
+        inputs = inputs[-1:]
     return inputs, outputs

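With the flag set, `inputs[-1:]` keeps only the last prepared Triton input, which is presumably the target_text tensor given the model's input order further down in this diff (reference_wav, reference_wav_len, reference_text, target_text, the first three now optional). A cached-speaker request therefore carries just the target text; the prompt audio and reference transcript never leave the client.
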
@@ -453,6 +461,7 @@ async def send_streaming(
     save_sample_rate: int = 16000,
     chunk_overlap_duration: float = 0.1,
     padding_duration: int = None,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []

@@ -478,7 +487,8 @@ async def send_streaming(
             reference_text,
             target_text,
             sample_rate,
-            padding_duration=padding_duration
+            padding_duration=padding_duration,
+            use_spk2info_cache=use_spk2info_cache
         )
         request_id = str(uuid.uuid4())
         user_data = UserData()

@@ -534,6 +544,7 @@ async def send(
     padding_duration: int = None,
     audio_save_dir: str = "./",
     save_sample_rate: int = 16000,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []

@@ -552,7 +563,8 @@ async def send(
             reference_text,
             target_text,
             sample_rate,
-            padding_duration=padding_duration
+            padding_duration=padding_duration,
+            use_spk2info_cache=use_spk2info_cache
         )
         sequence_id = 100000000 + i + task_id * 10
         start = time.time()

@@ -691,6 +703,7 @@ async def main():
                     audio_save_dir=args.log_dir,
                     padding_duration=1,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         elif args.mode == "streaming":

@@ -706,6 +719,7 @@ async def main():
                     padding_duration=10,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                     chunk_overlap_duration=args.chunk_overlap_duration,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         # --- End Task Creation ---

@@ -43,6 +43,7 @@ import torchaudio

 from matcha.utils.audio import mel_spectrogram

+ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)

@@ -81,6 +82,12 @@ class TritonPythonModel:
         self.flow_pre_lookahead_len = 3
         self.token_hop_len = 15

+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
     def forward_llm(self, input_ids):
         """
         Prepares the response from the language model based on the provided

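The cached entry loaded here can be inspected offline; the keys below are the ones this diff reads back (`prompt_text`, `speech_token`, `speech_feat`, `embedding`), while the printed shapes are whatever the downloaded spk2info.pt happens to contain:

```python
import torch

spk_info = torch.load("spk2info.pt", map_location="cpu", weights_only=False)
entry = spk_info["001"]  # the speaker id this recipe uses as the default
print(entry["prompt_text"])           # transcript of the cached prompt audio
print(entry["speech_token"].shape)    # prompt speech tokens (unshifted)
print(entry["speech_feat"].shape)     # prompt mel features for the flow model
print(entry["embedding"].shape)       # speaker embedding
```
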
@@ -220,11 +227,11 @@ class TritonPythonModel:

     def forward_token2wav(
         self,
-        prompt_speech_tokens: torch.Tensor,
-        prompt_speech_feat: torch.Tensor,
-        prompt_spk_embedding: torch.Tensor,
         target_speech_tokens: torch.Tensor,
         request_id: str,
+        prompt_speech_tokens: torch.Tensor = None,
+        prompt_speech_feat: torch.Tensor = None,
+        prompt_spk_embedding: torch.Tensor = None,
         token_offset: int = None,
         finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.

@@ -238,12 +245,9 @@ class TritonPythonModel:
         Returns:
             Generated waveform tensor
         """
-        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

-        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+        inputs_tensor = [target_speech_tokens_tensor]

         if token_offset is not None:
             assert finalize is not None

@@ -252,6 +256,13 @@ class TritonPythonModel:
             inputs_tensor.append(token_offset_tensor)
             inputs_tensor.append(finalize_tensor)

+        if prompt_spk_embedding is not None:
+            assert prompt_speech_feat is not None
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',

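The net effect of the reordering is that the prompt tensors become optional trailing arguments. The call sites updated later in this diff still pass them, but when the speaker cache is in use they arrive as `None`, the `if prompt_spk_embedding is not None` branch above is skipped, and the downstream token2wav model falls back to its own cached entry.
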
@@ -318,25 +329,30 @@ class TritonPythonModel:
         request_id = request.request_id()
         # Extract input tensors
         wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

         # Process reference audio through audio tokenizer
-        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-        wav_tensor = wav.as_numpy()
-        wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-        speech_feat = self._extract_speech_feat(prompt_speech_resample)
-        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-        flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
-
-        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-        reference_text = reference_text[0][0].decode('utf-8')
+        if wav is not None:
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+            wav_tensor = wav.as_numpy()
+            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+            speech_feat = self._extract_speech_feat(prompt_speech_resample)
+            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+            reference_text = reference_text[0][0].decode('utf-8')
+            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+        else:
+            # using pre-cached reference text
+            reference_text = self.default_spk_info["prompt_text"]
+            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            prompt_speech_feat = None
+            prompt_spk_embedding = None

         target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
         target_text = target_text[0][0].decode('utf-8')

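Note the token-id bookkeeping here: spk2info.pt stores unshifted speech tokens, so the cached branch adds `ORIGINAL_VOCAB_SIZE` to place them in the LLM's extended vocabulary, and token2wav later subtracts the same constant (or, in its own cached branch, uses the raw stored tokens directly). A trivial sketch of the round trip:

```python
ORIGINAL_VOCAB_SIZE = 151663

cached_token = 42                                 # as stored in spk2info.pt
llm_token = cached_token + ORIGINAL_VOCAB_SIZE    # what the BLS/LLM side works with
assert llm_token - ORIGINAL_VOCAB_SIZE == cached_token  # token2wav shifts back
```
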
@@ -350,7 +366,6 @@ class TritonPythonModel:

         # Generate semantic tokens with LLM
         generated_ids_iter = self.forward_llm(input_ids)
-        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

         if self.decoupled:
             response_sender = request.get_response_sender()

@@ -380,8 +395,9 @@ class TritonPythonModel:
                 this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)

                 sub_tts_speech = self.forward_token2wav(
-                    prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
-                    this_tts_speech_token, request_id, token_offset, False)
+                    this_tts_speech_token, request_id, prompt_speech_tokens,
+                    prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                )

                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])

@@ -414,7 +430,7 @@ class TritonPythonModel:
                 time.sleep(0.02)

             this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-            sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+            sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
             response_sender.send(inference_response)

@@ -428,7 +444,7 @@ class TritonPythonModel:
             if generated_ids is None or len(generated_ids) == 0:
                 raise pb_utils.TritonModelException("Generated IDs is None or empty")

-            audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
+            audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

             # Prepare response
             audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

@@ -37,16 +37,19 @@ input [
     name: "reference_wav"
     data_type: TYPE_FP32
     dims: [-1]
+    optional: true
   },
   {
     name: "reference_wav_len"
     data_type: TYPE_INT32
     dims: [1]
+    optional: true
   },
   {
     name: "reference_text"
     data_type: TYPE_STRING
     dims: [1]
+    optional: true
   },
   {
     name: "target_text"

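With the three reference inputs marked `optional`, a client can omit them entirely and synthesize with the server-side cached speaker. A minimal sketch against an offline (non-decoupled) deployment, assuming a local Triton gRPC endpoint and the model name used in run.sh:

```python
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

# Only target_text is sent; reference_wav / reference_wav_len / reference_text are omitted.
target_text = np.array([["Hello from the cached speaker."]], dtype=object)
text_input = grpcclient.InferInput("target_text", target_text.shape, "BYTES")
text_input.set_data_from_numpy(target_text)

result = client.infer(
    model_name="cosyvoice2",
    inputs=[text_input],
    outputs=[grpcclient.InferRequestedOutput("waveform")],
)
waveform = result.as_numpy("waveform")
```
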
@@ -187,6 +187,12 @@ class TritonPythonModel:
             model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
         )

+        spk_info_path = os.path.join(model_dir, "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_dir}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
         logger.info("Token2Wav initialized successfully")

     def execute(self, requests):

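Both Python models now load the same spk2info.pt and pick entry "001": the cosyvoice2 BLS model uses its prompt_text and speech_token to build the LLM prompt, while token2wav (below) uses speech_token, speech_feat, and embedding for vocoding, so neither needs the audio_tokenizer or speaker_embedding models when the cache is active.
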
@@ -202,17 +208,23 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
-            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
-            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()

             target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
-            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
-            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
-            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
+            if prompt_speech_tokens_tensor is not None:
+                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
+                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
+                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
+                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
+                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
+                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+            else:
+                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
+                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
+                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)

             # shift the speech tokens according to the original vocab size
-            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

             # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.

@@ -35,16 +35,19 @@ input [
     name: "prompt_speech_tokens"
     data_type: TYPE_INT32
     dims: [-1]
+    optional: true
   },
   {
     name: "prompt_speech_feat"
     data_type: TYPE_FP16
     dims: [-1, 80]
+    optional: true
   },
   {
     name: "prompt_spk_embedding"
     data_type: TYPE_FP16
     dims: [-1]
+    optional: true
   },
   {
     name: "token_offset"

@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}

 model_repo=./model_repo_cosyvoice2

+use_spk2info_cache=True
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   echo "Cloning CosyVoice"
   git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path

@@ -27,6 +29,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   echo "Downloading CosyVoice2-0.5B"
   huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
   modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+  # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
+  wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
 fi

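The downloaded spk2info.pt is simply a dict mapping a speaker id to the pre-computed prompt artifacts read back above. A hedged sketch of writing such a file for your own prompt audio (the tensor shapes below are placeholders; the real values come from the audio tokenizer, the 24 kHz mel extraction, and the speaker-embedding model shown in the BLS diff):

```python
import torch

# Placeholder tensors; in practice these come from the extraction pipeline.
speech_token = torch.zeros(1, 150, dtype=torch.int32)  # unshifted prompt speech tokens
speech_feat = torch.zeros(1, 300, 80)                  # mel features, matching dims [-1, 80]
embedding = torch.zeros(1, 192)                        # speaker embedding (size illustrative)

torch.save(
    {"001": {
        "prompt_text": "transcript of the prompt audio",
        "speech_token": speech_token,
        "speech_feat": speech_feat,
        "embedding": embedding,
    }},
    "spk2info.pt",
)
```
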
@@ -57,10 +61,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   cosyvoice2_dir="cosyvoice2"

   cp -r ./model_repo/${cosyvoice2_dir} $model_repo
-  cp -r ./model_repo/audio_tokenizer $model_repo
   cp -r ./model_repo/tensorrt_llm $model_repo
   cp -r ./model_repo/token2wav $model_repo
-  cp -r ./model_repo/speaker_embedding $model_repo
+  if [ $use_spk2info_cache == "False" ]; then
+    cp -r ./model_repo/audio_tokenizer $model_repo
+    cp -r ./model_repo/speaker_embedding $model_repo
+  fi

   ENGINE_PATH=$trt_engines_dir
   MAX_QUEUE_DELAY_MICROSECONDS=0

@@ -71,11 +77,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   DECOUPLED_MODE=True # True for streaming, False for offline

   python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-  python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
   python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-  python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
   python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+  if [ $use_spk2info_cache == "False" ]; then
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+  fi
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then

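With use_spk2info_cache left at True, the audio_tokenizer and speaker_embedding repositories are neither copied nor templated, so those two Triton models simply do not exist in the deployment; set the variable to "False" to restore the original per-request extraction path.
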
@@ -94,7 +101,7 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   echo "Running benchmark client grpc"
-  num_task=1
+  num_task=4

   mode=streaming
   BLS_INSTANCE_NUM=4

@@ -104,6 +111,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     --model-name cosyvoice2 \
     --num-tasks $num_task \
     --mode $mode \
+    --use-spk2info-cache $use_spk2info_cache \
     --huggingface-dataset yuekai/seed_tts_cosy2 \
-    --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
+    --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
 fi