mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge pull request #1566 from yuekaizhang/streaming
[runtime: TRT-LLM] support prompt audio cache & offline inference mode
README.md: 11 lines changed
@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```

#### Using NVIDIA TensorRT-LLM for deployment

Using TensorRT-LLM to accelerate the CosyVoice2 LLM can give roughly a 4x speedup compared with the Hugging Face Transformers implementation.

To get started quickly:

``` sh
cd runtime/triton_trtllm
docker compose up -d
```

For more details, see [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm).
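Once the containers are up, you can send a request to the Triton server directly. The sketch below is a minimal, hypothetical HTTP client, assuming the default Triton HTTP port (8000), a server exported with `DECOUPLED_MODE=False` (offline mode; streaming requires the gRPC client), and the `cosyvoice2` tensor names from the runtime config in this PR (`reference_wav`, `reference_wav_len`, `reference_text`, `target_text`, output `waveform`); the client scripts under `runtime/triton_trtllm` remain the reference implementation.

```python
# Hedged sketch: single zero-shot request to the "cosyvoice2" Triton model over HTTP.
import numpy as np
import soundfile as sf
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

wav, sr = sf.read("prompt.wav", dtype="float32")          # 16 kHz mono prompt audio
wav = wav[np.newaxis, :]                                   # add batch dim -> [1, T]
wav_len = np.array([[wav.shape[1]]], dtype=np.int32)
ref_text = np.array([[b"transcript of the prompt audio"]], dtype=object)
tgt_text = np.array([[b"text to synthesize"]], dtype=object)

inputs = [
    httpclient.InferInput("reference_wav", list(wav.shape), "FP32"),
    httpclient.InferInput("reference_wav_len", list(wav_len.shape), "INT32"),
    httpclient.InferInput("reference_text", list(ref_text.shape), "BYTES"),
    httpclient.InferInput("target_text", list(tgt_text.shape), "BYTES"),
]
for inp, data in zip(inputs, (wav, wav_len, ref_text, tgt_text)):
    inp.set_data_from_numpy(data)

result = client.infer("cosyvoice2", inputs,
                      outputs=[httpclient.InferRequestedOutput("waveform")])
audio = result.as_numpy("waveform").reshape(-1)            # 24 kHz output waveform
sf.write("output.wav", audio, 24000)
```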
## Discussion & Communication
You can discuss directly on [GitHub Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
@@ -1,4 +1,4 @@
## Serving CosyVoice with NVIDIA Triton Inference Server
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM

Contributed by Yuekai Zhang (NVIDIA).

@@ -41,6 +41,7 @@ bash run.sh <start_stage> <stop_stage> [service_type]
- **Stage 3**: Launches the Triton Inference Server.
- **Stage 4**: Runs the single-utterance HTTP client for testing.
- **Stage 5**: Runs the gRPC benchmark client.
- **Stage 6**: Runs the offline inference benchmark test.
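For example, a usage sketch of the staged workflow (the exact work done by each intermediate stage is defined in `run.sh`; this assumes stages 0-3 cover model download, engine building, model-repo assembly, and server launch):

```sh
# download models, build the TRT-LLM engine and model repo, and start Triton
bash run.sh 0 3

# then, in another shell, smoke-test with the single-utterance HTTP client
bash run.sh 4 4
```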

### Export Models and Launch Server

@@ -59,7 +60,7 @@ Sends a single HTTP inference request. This is intended for testing the offline
bash run.sh 4 4
```

### Benchmark with a Dataset
### Benchmark with client-server mode

To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
```sh
@@ -71,23 +72,57 @@ bash run.sh 5 5 # [streaming|offline]

> [!TIP]
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.

### Benchmark with offline inference mode
For the offline inference mode benchmark, use the following commands:
```sh
# install FlashCosyVoice for token2wav batching
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
# cd /workspace/FlashCosyVoice
# pip install -e .
# cd -
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx

bash run.sh 6 6

# You can also switch to huggingface backend by setting backend=hf
```
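The offline benchmark writes its timing metrics to a `log.txt` inside each run's output directory (see `offline_inference.py` added in this PR). A quick way to inspect them, assuming the default output-directory naming from `run.sh` stage 6:

```sh
# keys include llm_time_seconds, token2wav_time_seconds, total_audio_duration_seconds, pipeline_time_seconds
cat ./wenetspeech4tts_trtllm_llm_batch_size_16_token2wav_batch_size_1/log.txt
```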

### Benchmark Results
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):

**Streaming TTS (First Chunk Latency)**
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, Decoupled=True | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, Decoupled=True | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, Decoupled=True | 4 | 1107.34 | 1243.75 | 0.0922 |
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75 | 0.0922 |
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68 | 0.0733 |

**Offline TTS (Full Sentence Latency)**
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
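For reference, a hedged sketch of what a `spk2info.pt` entry looks like, inferred from how the server code in this PR consumes it (fields `prompt_text`, `speech_token`, `speech_feat`, `embedding`, keyed by a speaker id such as `"001"`); check field shapes and dtypes against the downloaded file before relying on them:

```python
# Hedged sketch: inspect / extend spk2info.pt as consumed by the model.py files in this PR.
import torch

spk2info = torch.load("spk2info.pt", map_location="cpu", weights_only=False)
print(list(spk2info.keys()))               # e.g. ["001", ...]
entry = spk2info["001"]
print(entry["prompt_text"])                # cached reference transcript (str)
print(entry["speech_token"].shape)         # cached prompt speech tokens
print(entry["speech_feat"].shape)          # cached 24 kHz mel features
print(entry["embedding"].shape)            # cached campplus speaker embedding

# Adding a speaker amounts to storing the same four fields under a new key
# (see the async_cosyvoice instructions linked above for the full recipe).
spk2info["my_speaker"] = {
    "prompt_text": "transcript of my prompt audio",
    "speech_token": entry["speech_token"],  # placeholder: compute these from your own audio
    "speech_feat": entry["speech_feat"],
    "embedding": entry["embedding"],
}
torch.save(spk2info, "spk2info.pt")
```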

**Client-Server Mode: Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|---|
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |

**Offline Inference Mode: Hugging Face LLM vs. TensorRT-LLM**
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|--------------------|-----|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
| HF | 2 | 30.54 | 35.62 | 0.2064 |
| HF | 4 | 18.63 | 23.90 | 0.1421 |
| HF | 8 | 11.22 | 16.45 | 0.0947 |
| HF | 16 | 8.42 | 13.78 | 0.0821 |
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
| TRTLLM | 2 | 7.64 | 12.65 | 0.0739 |
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
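(For reading these numbers: RTF here is wall-clock time divided by the duration of the generated audio, so with roughly 170 seconds of audio in this set the TRT-LLM batch-16 run works out to about 6.63 / 170 ≈ 0.039, matching the table.)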

### OpenAI-Compatible Server

To launch an OpenAI-compatible API service, run the following commands:
@@ -257,7 +257,13 @@ def get_args():
|
||||
default=0.1,
|
||||
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||
)
|
||||
# --- End Added arguments ---
|
||||
|
||||
parser.add_argument(
|
||||
"--use-spk2info-cache",
|
||||
type=lambda s: str(s).lower() in ("true", "1", "yes"),  # note: plain type=bool would treat the string "False" as True
|
||||
default=False,
|
||||
help="Use spk2info cache for reference audio.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -283,7 +289,8 @@ def prepare_request_input_output(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
padding_duration: int = None # Optional padding for offline mode
|
||||
padding_duration: int = None, # Optional padding for offline mode
|
||||
use_spk2info_cache: bool = False
|
||||
):
|
||||
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
@@ -330,7 +337,8 @@ def prepare_request_input_output(
|
||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||
|
||||
if use_spk2info_cache:
|
||||
inputs = inputs[-1:]
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
@@ -453,6 +461,7 @@ async def send_streaming(
|
||||
save_sample_rate: int = 16000,
|
||||
chunk_overlap_duration: float = 0.1,
|
||||
padding_duration: int = None,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -478,7 +487,8 @@ async def send_streaming(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
request_id = str(uuid.uuid4())
|
||||
user_data = UserData()
|
||||
@@ -534,6 +544,7 @@ async def send(
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
use_spk2info_cache: bool = False,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -552,7 +563,8 @@ async def send(
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate,
|
||||
padding_duration=padding_duration
|
||||
padding_duration=padding_duration,
|
||||
use_spk2info_cache=use_spk2info_cache
|
||||
)
|
||||
sequence_id = 100000000 + i + task_id * 10
|
||||
start = time.time()
|
||||
@@ -691,6 +703,7 @@ async def main():
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
elif args.mode == "streaming":
|
||||
@@ -706,6 +719,7 @@ async def main():
|
||||
padding_duration=10,
|
||||
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||
use_spk2info_cache=args.use_spk2info_cache,
|
||||
)
|
||||
)
|
||||
# --- End Task Creation ---
|
||||
|
||||
@@ -43,6 +43,7 @@ import torchaudio
|
||||
|
||||
from matcha.utils.audio import mel_spectrogram
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
@@ -81,6 +82,12 @@ class TritonPythonModel:
|
||||
self.flow_pre_lookahead_len = 3
|
||||
self.token_hop_len = 15
|
||||
|
||||
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
self.default_spk_info = spk_info["001"]
|
||||
|
||||
def forward_llm(self, input_ids):
|
||||
"""
|
||||
Prepares the response from the language model based on the provided
|
||||
@@ -220,11 +227,11 @@ class TritonPythonModel:
|
||||
|
||||
def forward_token2wav(
|
||||
self,
|
||||
prompt_speech_tokens: torch.Tensor,
|
||||
prompt_speech_feat: torch.Tensor,
|
||||
prompt_spk_embedding: torch.Tensor,
|
||||
target_speech_tokens: torch.Tensor,
|
||||
request_id: str,
|
||||
prompt_speech_tokens: torch.Tensor = None,
|
||||
prompt_speech_feat: torch.Tensor = None,
|
||||
prompt_spk_embedding: torch.Tensor = None,
|
||||
token_offset: int = None,
|
||||
finalize: bool = None) -> torch.Tensor:
|
||||
"""Forward pass through the vocoder component.
|
||||
@@ -238,12 +245,9 @@ class TritonPythonModel:
|
||||
Returns:
|
||||
Generated waveform tensor
|
||||
"""
|
||||
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
||||
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
||||
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
|
||||
|
||||
inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
|
||||
inputs_tensor = [target_speech_tokens_tensor]
|
||||
|
||||
if token_offset is not None:
|
||||
assert finalize is not None
|
||||
@@ -252,6 +256,13 @@ class TritonPythonModel:
|
||||
inputs_tensor.append(token_offset_tensor)
|
||||
inputs_tensor.append(finalize_tensor)
|
||||
|
||||
if prompt_spk_embedding is not None:
|
||||
assert prompt_speech_feat is not None
|
||||
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
|
||||
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
|
||||
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
|
||||
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
|
||||
|
||||
# Create and execute inference request
|
||||
inference_request = pb_utils.InferenceRequest(
|
||||
model_name='token2wav',
|
||||
@@ -318,25 +329,30 @@ class TritonPythonModel:
|
||||
request_id = request.request_id()
|
||||
# Extract input tensors
|
||||
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
|
||||
# Process reference audio through audio tokenizer
|
||||
if wav is not None:
|
||||
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||
|
||||
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
||||
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
|
||||
wav_tensor = wav.as_numpy()
|
||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||
|
||||
wav_tensor = wav.as_numpy()
|
||||
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
|
||||
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
|
||||
speech_feat = self._extract_speech_feat(prompt_speech_resample)
|
||||
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
|
||||
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
|
||||
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
|
||||
|
||||
flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
|
||||
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
||||
reference_text = reference_text[0][0].decode('utf-8')
|
||||
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||
else:
|
||||
# using pre-cached reference text
|
||||
reference_text = self.default_spk_info["prompt_text"]
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
|
||||
prompt_speech_feat = None
|
||||
prompt_spk_embedding = None
|
||||
|
||||
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
||||
target_text = target_text[0][0].decode('utf-8')
|
||||
@@ -350,7 +366,6 @@ class TritonPythonModel:
|
||||
|
||||
# Generate semantic tokens with LLM
|
||||
generated_ids_iter = self.forward_llm(input_ids)
|
||||
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
|
||||
|
||||
if self.decoupled:
|
||||
response_sender = request.get_response_sender()
|
||||
@@ -380,8 +395,9 @@ class TritonPythonModel:
|
||||
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
|
||||
sub_tts_speech = self.forward_token2wav(
|
||||
prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding,
|
||||
this_tts_speech_token, request_id, token_offset, False)
|
||||
this_tts_speech_token, request_id, prompt_speech_tokens,
|
||||
prompt_speech_feat, prompt_spk_embedding, token_offset, False
|
||||
)
|
||||
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
@@ -414,7 +430,7 @@ class TritonPythonModel:
|
||||
time.sleep(0.02)
|
||||
|
||||
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
|
||||
sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
|
||||
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
||||
response_sender.send(inference_response)
|
||||
@@ -428,7 +444,7 @@ class TritonPythonModel:
|
||||
if generated_ids is None or len(generated_ids) == 0:
|
||||
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
||||
|
||||
audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
|
||||
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
|
||||
|
||||
# Prepare response
|
||||
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
|
||||
@@ -37,16 +37,19 @@ input [
|
||||
name: "reference_wav"
|
||||
data_type: TYPE_FP32
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_wav_len"
|
||||
data_type: TYPE_INT32
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "reference_text"
|
||||
data_type: TYPE_STRING
|
||||
dims: [1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "target_text"
|
||||
|
||||
@@ -187,6 +187,12 @@ class TritonPythonModel:
|
||||
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
|
||||
)
|
||||
|
||||
spk_info_path = os.path.join(model_dir, "spk2info.pt")
|
||||
if not os.path.exists(spk_info_path):
|
||||
raise ValueError(f"spk2info.pt not found in {model_dir}")
|
||||
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
|
||||
self.default_spk_info = spk_info["001"]
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
|
||||
def execute(self, requests):
|
||||
@@ -202,17 +208,23 @@ class TritonPythonModel:
|
||||
# Process each request in batch
|
||||
for request in requests:
|
||||
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
|
||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
|
||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||
|
||||
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||
|
||||
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
|
||||
if prompt_speech_tokens_tensor is not None:
|
||||
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
|
||||
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
|
||||
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
|
||||
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
|
||||
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
|
||||
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
else:
|
||||
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
|
||||
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
|
||||
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
|
||||
|
||||
# shift the speech tokens according to the original vocab size
|
||||
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
|
||||
|
||||
# We set token_offset as an optional input to support both streaming and offline TTS. It must be None for offline TTS.
|
||||
|
||||
@@ -35,16 +35,19 @@ input [
|
||||
name: "prompt_speech_tokens"
|
||||
data_type: TYPE_INT32
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_speech_feat"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1, 80]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "prompt_spk_embedding"
|
||||
data_type: TYPE_FP16
|
||||
dims: [-1]
|
||||
optional: true
|
||||
},
|
||||
{
|
||||
name: "token_offset"
|
||||
|
||||
runtime/triton_trtllm/offline_inference.py: 563 lines (new file)
@@ -0,0 +1,563 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 offline_inference.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||
--token2wav-path $model_scope_model_local_dir \
|
||||
--backend $backend \
|
||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||
--engine-dir $trt_engines_dir \
|
||||
--split-name ${dataset} || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
import soundfile as sf
|
||||
import s3tokenizer
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
from token2wav import CosyVoice2_Token2Wav
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def extract_speech_ids(speech_tokens_str):
|
||||
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||
speech_ids = []
|
||||
for token_str in speech_tokens_str:
|
||||
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||
num_str = token_str[4:-2]
|
||||
num = int(num_str)
|
||||
speech_ids.append(num)
|
||||
else:
|
||||
print(f"Unexpected token: {token_str}")
|
||||
return speech_ids
|
||||
|
||||
|
||||
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||
speech_id_str = ""
|
||||
for token in cosy2_tokens:
|
||||
speech_id_str += f"<|s_{token}|>"
|
||||
return speech_id_str
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=0, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=None, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model-name-or-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="LLM model path (includes both model and tokenizer)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="CosyVoice2 token2wav model path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-text",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt text for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-speech-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the prompt speech for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="top p for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="temperature for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="top k for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["hf", "trtllm", "vllm"],
|
||||
help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--engine-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-free-gpu-memory-fraction",
|
||||
type=float,
|
||||
default=0.6,
|
||||
help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||
"""Simplified data collator for batch_size=1 processing"""
|
||||
collator_start_time = time.time()
|
||||
total_audio_processing_time = 0
|
||||
total_speech_tokenization_time = 0
|
||||
total_text_tokenization_time = 0
|
||||
|
||||
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||
prompt_text_after_apply_template_list = []
|
||||
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
|
||||
for _, item in enumerate(batch):
|
||||
audio_processing_start_time = time.time()
|
||||
prompt_text, target_text = (
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
prompt_text_list.append(prompt_text)
|
||||
full_text = prompt_text + target_text
|
||||
full_text_list.append(full_text)
|
||||
# remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
|
||||
puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
|
||||
for p in puncts:
|
||||
if p in full_text:
|
||||
full_text = full_text.replace(p, '')
|
||||
print(f"removed {p} from {full_text}")
|
||||
|
||||
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||
print(ref_audio_org.shape)
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
prompt_audio_list.append(ref_audio)
|
||||
audio_processing_end_time = time.time()
|
||||
total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
|
||||
|
||||
speech_tokenization_start_time = time.time()
|
||||
if "prompt_audio_cosy2_tokens" in item:
|
||||
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||
else:
|
||||
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||
|
||||
if len(mels) > 0:
|
||||
mels, mels_lens = s3tokenizer.padding(mels)
|
||||
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||
for i in range(len(codes)):
|
||||
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||
speech_tokenization_end_time = time.time()
|
||||
total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
|
||||
|
||||
for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
|
||||
text_tokenization_start_time = time.time()
|
||||
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||
# Create chat template for LLM generation
|
||||
chat = [
|
||||
{"role": "user", "content": full_text_list[i]},
|
||||
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||
]
|
||||
|
||||
assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
tokenize=True,
|
||||
return_tensors='pt',
|
||||
continue_final_message=True
|
||||
)
|
||||
input_ids_list.append(input_ids.squeeze(0))
|
||||
|
||||
prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
|
||||
|
||||
prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
|
||||
text_tokenization_end_time = time.time()
|
||||
total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
|
||||
|
||||
ids = [item["id"] for item in batch]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids_list,
|
||||
"ids": ids,
|
||||
"prompt_text": prompt_text_list,
|
||||
"prompt_audio_list": prompt_audio_list,
|
||||
"prompt_text_after_apply_template": prompt_text_after_apply_template_list,
|
||||
"audio_processing_time": total_audio_processing_time,
|
||||
"speech_tokenization_time": total_speech_tokenization_time,
|
||||
"text_tokenization_time": total_text_tokenization_time,
|
||||
}
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group("nccl")
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
local_rank, world_size, rank = 0, 1, 0
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||
|
||||
if args.backend == "hf":
|
||||
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
runner = None
|
||||
elif args.backend == "trtllm":
|
||||
if args.engine_dir is None:
|
||||
raise ValueError("--engine-dir is required when backend is 'trtllm'")
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
model = None
|
||||
|
||||
runner_kwargs = dict(
|
||||
engine_dir=args.engine_dir,
|
||||
rank=runtime_rank,
|
||||
max_output_len=2048,
|
||||
enable_context_fmha_fp32_acc=False,
|
||||
max_batch_size=args.batch_size,
|
||||
max_input_len=512,
|
||||
kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
|
||||
cuda_graph_mode=False,
|
||||
gather_generation_logits=False,
|
||||
)
|
||||
|
||||
runner = ModelRunnerCpp.from_dir(**runner_kwargs)
|
||||
elif args.backend == "vllm":
|
||||
model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
|
||||
runner = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
token2wav_model = CosyVoice2_Token2Wav(
|
||||
model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
|
||||
)
|
||||
if args.prompt_speech_path:
|
||||
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||
else:
|
||||
prompt_speech_16k = None
|
||||
s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
|
||||
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
sampler = None
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||
)
|
||||
for _ in range(3):
|
||||
print(f"Running {_} times")
|
||||
total_llm_time = 0
|
||||
total_token2wav_time = 0
|
||||
total_data_load_time = 0
|
||||
total_llm_post_processing_time = 0
|
||||
total_audio_save_time = 0
|
||||
total_audio_processing_time_in_collator = 0
|
||||
total_speech_tokenization_time_in_collator = 0
|
||||
total_text_tokenization_time_in_collator = 0
|
||||
total_audio_samples = 0
|
||||
start_time = time.time()
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
last_batch_end_time = time.time()
|
||||
for batch in dataloader:
|
||||
data_loaded_time = time.time()
|
||||
total_data_load_time += data_loaded_time - last_batch_end_time
|
||||
total_audio_processing_time_in_collator += batch["audio_processing_time"]
|
||||
total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
|
||||
total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
|
||||
with torch.no_grad():
|
||||
llm_start_time = time.time()
|
||||
if args.backend == "hf":
|
||||
input_ids_list = batch["input_ids"]
|
||||
if len(input_ids_list) == 1:
|
||||
input_ids = input_ids_list[0].unsqueeze(0)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||
input_ids_list_new = [
|
||||
torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
|
||||
for input_ids in input_ids_list
|
||||
]
|
||||
input_ids = torch.stack(input_ids_list_new)
|
||||
attention_mask = torch.zeros_like(input_ids)
|
||||
for i in range(len(input_ids_list)):
|
||||
attention_mask[i, :len(input_ids_list[i])] = 1
|
||||
|
||||
input_ids = input_ids.to(device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
max_new_tokens=2048,
|
||||
do_sample=True,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
repetition_penalty=1.1,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
elif args.backend == "trtllm":
|
||||
batch_input_ids = list(batch["input_ids"])
|
||||
input_lengths = [x.size(0) for x in batch_input_ids]
|
||||
|
||||
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
|
||||
print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
|
||||
outputs = runner.generate(
|
||||
batch_input_ids=batch_input_ids,
|
||||
max_new_tokens=2048,
|
||||
end_id=end_id,
|
||||
pad_id=end_id,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1,
|
||||
streaming=False,
|
||||
output_sequence_lengths=True,
|
||||
output_generation_logits=False,
|
||||
return_dict=True,
|
||||
return_all_generated_tokens=False
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
|
||||
num_output_sents, num_beams, _ = output_ids.size()
|
||||
assert num_beams == 1
|
||||
beam = 0
|
||||
batch_size = len(batch["input_ids"])
|
||||
num_return_sequences = num_output_sents // batch_size
|
||||
assert num_return_sequences == 1
|
||||
outputs = []
|
||||
for i in range(batch_size * num_return_sequences):
|
||||
batch_idx = i // num_return_sequences
|
||||
seq_idx = i % num_return_sequences
|
||||
output_begin = input_lengths[batch_idx]
|
||||
output_end = sequence_lengths[i][beam]
|
||||
outputs_i = output_ids[i][beam][:output_end].tolist()
|
||||
outputs.append(outputs_i)
|
||||
elif args.backend == "vllm":
|
||||
input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
|
||||
sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
top_k=args.top_k,
|
||||
repetition_penalty=1.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
|
||||
print(outputs)
|
||||
for j, output in enumerate(outputs):
|
||||
outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
|
||||
|
||||
llm_end_time = time.time()
|
||||
total_llm_time += (llm_end_time - llm_start_time)
|
||||
|
||||
items_for_token_2wav = []
|
||||
for i in range(len(batch["ids"])):
|
||||
llm_post_processing_start_time = time.time()
|
||||
input_length = len(batch["input_ids"][i])
|
||||
generated_ids = outputs[i][input_length:]
|
||||
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||
print(i, speech_ids)
|
||||
if len(speech_ids) == 0:
|
||||
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||
continue
|
||||
|
||||
if args.prompt_text is not None:
|
||||
current_prompt_text = args.prompt_text
|
||||
current_prompt_audio = prompt_speech_16k
|
||||
else:
|
||||
current_prompt_text = batch["prompt_text"][i]
|
||||
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||
|
||||
llm_post_processing_end_time = time.time()
|
||||
total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
|
||||
if current_prompt_audio is not None:
|
||||
items_for_token_2wav.append({
|
||||
"speech_ids": speech_ids,
|
||||
"prompt_audio": current_prompt_audio.squeeze(0),
|
||||
"id": batch["ids"][i]
|
||||
})
|
||||
else:
|
||||
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||
|
||||
for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
|
||||
t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
|
||||
if not t2w_batch:
|
||||
continue
|
||||
|
||||
t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
|
||||
t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
|
||||
t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
|
||||
t2w_ids = [item["id"] for item in t2w_batch]
|
||||
|
||||
token2wav_start_time = time.time()
|
||||
generated_wavs = token2wav_model(
|
||||
t2w_generated_speech_tokens_list,
|
||||
t2w_prompt_audios_list,
|
||||
t2w_prompt_audios_sample_rate,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
token2wav_end_time = time.time()
|
||||
total_token2wav_time += (token2wav_end_time - token2wav_start_time)
|
||||
|
||||
audio_save_start_time = time.time()
|
||||
for j, audio_hat in enumerate(generated_wavs):
|
||||
generated_wave = audio_hat.squeeze().cpu().numpy()
|
||||
total_audio_samples += len(generated_wave)
|
||||
target_sample_rate = 24000
|
||||
|
||||
utt = t2w_ids[j]
|
||||
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||
print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
|
||||
audio_save_end_time = time.time()
|
||||
total_audio_save_time += audio_save_end_time - audio_save_start_time
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
|
||||
last_batch_end_time = time.time()
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
end_time = time.time()
|
||||
target_sample_rate = 24000
|
||||
total_audio_duration_seconds = total_audio_samples / target_sample_rate
|
||||
|
||||
log_file_path = os.path.join(args.output_dir, "log.txt")
|
||||
with open(log_file_path, 'w') as f:
|
||||
args_dict = vars(args)
|
||||
log_data = {
|
||||
"args": args_dict,
|
||||
"data_load_time_seconds": total_data_load_time,
|
||||
"audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
|
||||
"speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
|
||||
"text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
|
||||
"llm_time_seconds": total_llm_time,
|
||||
"llm_post_processing_time_seconds": total_llm_post_processing_time,
|
||||
"token2wav_time_seconds": total_token2wav_time,
|
||||
"audio_save_time_seconds": total_audio_save_time,
|
||||
"total_audio_duration_seconds": total_audio_duration_seconds,
|
||||
"pipeline_time_seconds": end_time - start_time,
|
||||
}
|
||||
print(log_data)
|
||||
f.write(json.dumps(log_data, indent=4))
|
||||
print(f"Metrics logged to {log_file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
if args.backend == "vllm":
|
||||
from vllm import LLM, SamplingParams
|
||||
elif args.backend == "trtllm":
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.runtime import ModelRunnerCpp
|
||||
elif args.backend == "hf":
|
||||
from transformers import AutoModelForCausalLM
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
main(args)
|
||||
@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
|
||||
|
||||
model_repo=./model_repo_cosyvoice2
|
||||
|
||||
use_spk2info_cache=False
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
echo "Cloning CosyVoice"
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
|
||||
@@ -25,8 +27,11 @@ fi
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading CosyVoice2-0.5B"
|
||||
# see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
|
||||
huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
|
||||
# download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
|
||||
wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
|
||||
fi
|
||||
|
||||
|
||||
@@ -57,10 +62,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
cosyvoice2_dir="cosyvoice2"
|
||||
|
||||
cp -r ./model_repo/${cosyvoice2_dir} $model_repo
|
||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||
cp -r ./model_repo/tensorrt_llm $model_repo
|
||||
cp -r ./model_repo/token2wav $model_repo
|
||||
cp -r ./model_repo/speaker_embedding $model_repo
|
||||
if [ $use_spk2info_cache == "False" ]; then
|
||||
cp -r ./model_repo/audio_tokenizer $model_repo
|
||||
cp -r ./model_repo/speaker_embedding $model_repo
|
||||
fi
|
||||
|
||||
ENGINE_PATH=$trt_engines_dir
|
||||
MAX_QUEUE_DELAY_MICROSECONDS=0
|
||||
@@ -71,11 +78,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
DECOUPLED_MODE=True # True for streaming, False for offline
|
||||
|
||||
python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
||||
|
||||
if [ $use_spk2info_cache == "False" ]; then
|
||||
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
@@ -94,7 +102,7 @@ fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Running benchmark client grpc"
|
||||
num_task=1
|
||||
num_task=4
|
||||
|
||||
mode=streaming
|
||||
BLS_INSTANCE_NUM=4
|
||||
@@ -104,6 +112,31 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
--model-name cosyvoice2 \
|
||||
--num-tasks $num_task \
|
||||
--mode $mode \
|
||||
--use-spk2info-cache $use_spk2info_cache \
|
||||
--huggingface-dataset yuekai/seed_tts_cosy2 \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
|
||||
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
echo "stage 6: Offline inference benchmark"
|
||||
n_gpus=1
|
||||
datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
|
||||
backend=trtllm # hf, trtllm, vllm
|
||||
|
||||
batch_sizes=(16 8 4 2 1)
|
||||
token2wav_batch_size=1
|
||||
for batch_size in ${batch_sizes[@]}; do
|
||||
for dataset in ${datasets[@]}; do
|
||||
output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 offline_inference.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $huggingface_model_local_dir \
|
||||
--token2wav-path $model_scope_model_local_dir \
|
||||
--backend $backend \
|
||||
--batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
|
||||
--engine-dir $trt_engines_dir \
|
||||
--split-name ${dataset} || exit 1
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
runtime/triton_trtllm/token2wav.py: 335 lines (new file)
@@ -0,0 +1,335 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python3 token2wav.py --enable-trt || exit 1
|
||||
"""
|
||||
import torch
|
||||
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
||||
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
||||
from flashcosyvoice.utils.audio import mel_spectrogram
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import onnxruntime
|
||||
import s3tokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
import torchaudio
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import queue
|
||||
import time
|
||||
|
||||
|
||||
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
import tensorrt as trt
|
||||
logging.info("Converting onnx to trt...")
|
||||
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
# load onnx model
|
||||
with open(onnx_model, "rb") as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||
# set input shapes
|
||||
for i in range(len(trt_kwargs['input_names'])):
|
||||
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
||||
# set input and output data type
|
||||
for i in range(network.num_inputs):
|
||||
input_tensor = network.get_input(i)
|
||||
input_tensor.dtype = tensor_dtype
|
||||
for i in range(network.num_outputs):
|
||||
output_tensor = network.get_output(i)
|
||||
output_tensor.dtype = tensor_dtype
|
||||
config.add_optimization_profile(profile)
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
# save trt engine
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
self.device = device
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no available estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
|
||||
|
||||
class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
|
||||
super().__init__()
|
||||
self.device_id = device_id
|
||||
self.device = f"cuda:{device_id}"
|
||||
|
||||
self.flow = CausalMaskedDiffWithXvec()
|
||||
self.flow.half()
|
||||
self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
|
||||
self.hift = HiFTGenerator()
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
|
||||
|
||||
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
|
||||
|
||||
gpu = "l20"
|
||||
if enable_trt:
|
||||
self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
|
||||
f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
|
||||
1,
|
||||
True)
|
||||
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
|
||||
f'{model_dir}/campplus.onnx',
|
||||
1,
|
||||
False)
|
||||
|
||||
def forward_spk_embedding(self, spk_feat):
|
||||
if isinstance(self.spk_model, onnxruntime.InferenceSession):
|
||||
return self.spk_model.run(
|
||||
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
||||
)[0].flatten().tolist()
|
||||
else:
|
||||
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
with torch.cuda.device(self.device_id):
|
||||
torch.cuda.current_stream().synchronize()
|
||||
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
|
||||
batch_size = spk_feat.size(0)
|
||||
|
||||
with stream:
|
||||
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
|
||||
output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
|
||||
|
||||
data_ptrs = [spk_feat.contiguous().data_ptr(),
|
||||
output_tensor.contiguous().data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
|
||||
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.spk_model.release_estimator(spk_model, stream)
|
||||
|
||||
return output_tensor.cpu().numpy().flatten().tolist()
|
||||
|
||||
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
|
||||
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
|
||||
trt_kwargs = self.get_spk_trt_kwargs()
|
||||
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
|
||||
import tensorrt as trt
|
||||
with open(spk_model, 'rb') as f:
|
||||
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
|
||||
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
    def get_spk_trt_kwargs(self):
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

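    # Dynamic-shape optimization profile for the flow-matching estimator. The leading
    # dimension is twice the audio batch size, which appears to account for the
    # conditional/unconditional (classifier-free guidance) duplication inside the flow
    # decoder.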
    def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
        opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
        max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
                     (max_batch_size * 2, 80)]
        input_names = ["x", "mask", "mu", "cond", "t", "spks"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

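    # Prompt processing: each reference audio is tokenized with the speech_tokenizer_v2
    # model via s3tokenizer (flow prompt tokens), converted to 24 kHz mel frames (flow
    # prompt features), and reduced to a CAM++ speaker embedding.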
    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
            prompt_speech_mels_list.append(log_mel)
        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
        )
        for i in range(len(prompt_speech_tokens)):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list

    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.forward_spk_embedding(spk_feat)
            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        return spk_emb_for_flow

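    # Prompt mel extraction for the flow model: audio is resampled to 24 kHz, converted
    # to 80-bin mel spectrograms, then zero-padded to the longest prompt in the batch.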
    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
            assert len(audio.shape) == 1
            audio = audio.unsqueeze(0)
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
            mel_len = mel.shape[0]
            prompt_mels_for_flow.append(mel)
            prompt_mels_lens_for_flow.append(mel_len)
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

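    # Flow matching: prompt and generated speech tokens are concatenated per utterance
    # and decoded into mel spectrograms, conditioned on the prompt mels and the speaker
    # embedding, under fp16 autocast.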
    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
                     prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
        batch_size = prompt_mels_for_flow.shape[0]
        flow_inputs = []
        flow_inputs_lens = []
        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))

        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
        flow_inputs_lens = torch.tensor(flow_inputs_lens)

        with torch.amp.autocast(self.device, dtype=torch.float16):
            generated_mels, generated_mels_lens = self.flow(
                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
                streaming=False, finalize=True
            )

        return generated_mels, generated_mels_lens

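    # Vocoding: the prompt's mel frames are trimmed from each generated mel before the
    # HiFT generator synthesizes the 24 kHz waveform, so only newly generated speech is
    # returned.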
    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
        batch_size = generated_mels.shape[0]
        generated_wavs = []
        for i in range(batch_size):
            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
            wav, _ = self.hift(speech_feat=mel)
            generated_wavs.append(wav)
        return generated_wavs

    @torch.inference_mode()
    def forward(
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # all items in prompt_audios_sample_rate must be 16000
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)

        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)

        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)

        generated_mels, generated_mels_lens = self.forward_flow(
            prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)

        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)

        return generated_wavs


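# Batching helper for the Hugging Face dataset: each item is expected to provide
# pre-generated CosyVoice2 speech tokens ('target_audio_cosy2_tokens') plus the prompt
# audio array and its sampling rate.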
def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])

    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs; performance statistics will only be collected from the last epoch")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
    # mkdir output_dir if not exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    dataset_name = "yuekai/seed_tts_cosy2"

    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)

    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for _ in range(args.warmup):
        start_time = time.time()

        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)

            for id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)

        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
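
# Example invocation (a sketch; the script filename below is illustrative and should
# match wherever this file lives under runtime/triton_trtllm):
#   python3 offline_inference.py --model-dir ./CosyVoice2-0.5B --batch-size 4 \
#       --enable-trt --output-dir generated_wavs --warmup 3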