Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 18:09:24 +08:00)
fix lint
@@ -48,9 +48,11 @@ import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)


def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
@@ -65,6 +67,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:

return hasher.hexdigest()


class TritonPythonModel:
"""Triton Python model for vocoder.
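The hexdigest above is only the tail of get_spk_id_from_prompt_audio; the hashing itself sits outside this hunk. As a reference, a minimal sketch of deriving a stable speaker id from a prompt-audio tensor, assuming a sha256 over the raw tensor bytes (the repository's exact digest and update scheme may differ):

import hashlib
import torch


def spk_id_from_tensor_sketch(tensor: torch.Tensor) -> str:
    # Assumption: hash the CPU byte buffer; any stable digest of the prompt audio works here.
    hasher = hashlib.sha256()
    hasher.update(tensor.detach().cpu().numpy().tobytes())
    return hasher.hexdigest()


# Identical prompt audio yields the same id, so per-speaker caches can be reused across requests.
print(spk_id_from_tensor_sketch(torch.zeros(1, 16000)))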
@@ -114,7 +117,6 @@ class TritonPythonModel:

request_id = request.request_id()


wav_array = pb_utils.get_input_tensor_by_name(
    request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
@@ -125,7 +127,10 @@ class TritonPythonModel:

spk_id = get_spk_id_from_prompt_audio(wav)

- audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
+ audio_hat = self.token2wav_model.forward_streaming(
+     target_speech_tokens, finalize, request_id=request_id,
+     speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+ )

outputs = []
@@ -35,7 +35,7 @@ import numpy as np
from hyperpyyaml import load_hyperpyyaml


- def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+ def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
"""perform fade_in_out in tensor style
"""
mel_overlap_len = int(window.shape[0] / 2)
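Only fragments of fade_in_out appear in this and the next hunk. For orientation, a self-contained sketch of the cross-fade it performs, assuming the first half of the window fades the new mel in and the second half fades the old mel out (the exact body in the repository may differ):

import torch


def fade_in_out_sketch(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor) -> torch.Tensor:
    # Blend the first half-window of the incoming mel with the last half-window of the outgoing mel.
    mel_overlap_len = int(window.shape[0] / 2)
    fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :mel_overlap_len] = (
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
        + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    )
    return fade_in_mel


# Example: a 20-frame window gives a 10-frame overlap region.
out = fade_in_out_sketch(torch.randn(1, 80, 50), torch.randn(1, 80, 50), torch.hamming_window(20))
print(out.shape)  # torch.Size([1, 80, 50])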
@@ -45,6 +45,7 @@ def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torc
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
import tensorrt as trt
logging.info("Converting onnx to trt...")
@@ -90,6 +91,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")


class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
@@ -108,6 +110,7 @@ class TrtContextWrapper:
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])


class CosyVoice2_Token2Wav(torch.nn.Module):
def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
super().__init__()
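TrtContextWrapper hands TensorRT execution contexts out through a bounded queue so that at most trt_concurrent inferences use the engine at once. A stripped-down sketch of that acquire/release pattern with placeholder objects instead of real TensorRT contexts (names other than release_estimator are illustrative):

import queue


class ContextPoolSketch:
    """Blocking pool: get() waits until a [context, stream] pair is free; put() returns it."""

    def __init__(self, pairs):
        self.trt_context_pool = queue.Queue(maxsize=len(pairs))
        for pair in pairs:
            self.trt_context_pool.put(pair)

    def acquire_estimator(self):
        return self.trt_context_pool.get()  # blocks while all pairs are in use

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])


pool = ContextPoolSketch([["ctx0", "stream0"]])
context, stream = pool.acquire_estimator()
pool.release_estimator(context, stream)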
@@ -131,27 +134,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
- self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
-     providers=["CPUExecutionProvider"])
+ self.spk_model = onnxruntime.InferenceSession(
+     f"{model_dir}/campplus.onnx", sess_options=option,
+     providers=["CPUExecutionProvider"])
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()

- gpu="l20"
+ gpu = "l20"
if enable_trt:
if streaming:
- self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
-     f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
-     1,
-     self.dtype, streaming)
+ self.load_trt(
+     f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+     f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+     1,
+     self.dtype, streaming
+ )
else:
- self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
-     f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
-     1,
-     self.dtype)
- self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
-     f'{model_dir}/campplus.onnx',
-     1,
-     False)

+ self.load_trt(
+     f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+     f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+     1,
+     self.dtype
+ )
+ self.load_spk_trt(
+     f'{model_dir}/campplus.{gpu}.fp32.trt',
+     f'{model_dir}/campplus.onnx',
+     1,
+     False
+ )

self.streaming_flow_cache = {}
self.speaker_cache = {}
@@ -215,7 +224,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
opt_batch_size = 2
max_batch_size = 16
if streaming:
- opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+ opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
del self.flow.decoder.estimator
@@ -228,13 +237,27 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
if streaming:
min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
- opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
- max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+ opt_shape = [
+     (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+     (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+     (16, opt_batch_size * 2, 8, 100, 128)
+ ]
+ max_shape = [
+     (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+     (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+     (16, max_batch_size * 2, 8, 1000, 128)
+ ]
input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
else:
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
- opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
- max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+ opt_shape = [
+     (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+     (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+ ]
+ max_shape = [
+     (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+     (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+ ]
input_names = ["x", "mask", "mu", "cond", "t", "spks"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
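The min/opt/max shape triples built above are the standard ingredients of a TensorRT optimization profile. convert_onnx_to_trt is only partially visible in this diff, so the following is an illustrative sketch of how such kwargs are typically consumed when building the engine (the helper name add_dynamic_batch_profile is made up; the TensorRT builder and config objects are assumed to exist):

import tensorrt as trt


def add_dynamic_batch_profile(builder: "trt.Builder", config: "trt.IBuilderConfig", trt_kwargs: dict) -> None:
    # One (min, opt, max) entry per named dynamic input.
    profile = builder.create_optimization_profile()
    for name, min_s, opt_s, max_s in zip(trt_kwargs["input_names"], trt_kwargs["min_shape"],
                                         trt_kwargs["opt_shape"], trt_kwargs["max_shape"]):
        profile.set_shape(name, min_s, opt_s, max_s)
    config.add_optimization_profile(profile)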
@@ -279,11 +302,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
mel_len = mel.shape[0]
prompt_mels_for_flow.append(mel)
prompt_mels_lens_for_flow.append(mel_len)
- prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
+ prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+     prompt_mels_for_flow, batch_first=True, padding_value=0
+ )  # [B, T', num_mels=80]
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
return prompt_mels_for_flow, prompt_mels_lens_for_flow

- def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+ def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                  generated_speech_tokens_list: list[list[int]],
+                  prompt_mels_for_flow: torch.Tensor,
+                  prompt_mels_lens_for_flow: torch.Tensor,
+                  spk_emb_for_flow: torch.Tensor):
batch_size = prompt_mels_for_flow.shape[0]
flow_inputs = []
flow_inputs_lens = []
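For reference, a tiny self-contained illustration of what the pad_sequence call above produces for variable-length prompt mels (the lengths here are made up):

import torch

mels = [torch.randn(120, 80), torch.randn(95, 80)]  # two prompts, each [T_i, num_mels=80]
lens = torch.tensor([m.shape[0] for m in mels])
batch = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True, padding_value=0)
print(batch.shape, lens)  # torch.Size([2, 120, 80]) tensor([120, 95])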
@@ -311,7 +340,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
generated_wavs.append(wav)
return generated_wavs


@torch.inference_mode()
def forward(
    self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
@@ -320,7 +348,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)

- generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+ generated_mels, generated_mels_lens = self.forward_flow(
+     prompt_speech_tokens_list, generated_speech_tokens_list,
+     prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+ )

generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
return generated_wavs
@@ -337,7 +368,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow


def get_prompt_audio_cache_for_streaming_tts(
    self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
):
@@ -356,7 +386,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
# Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
return new_cache


@torch.inference_mode()
def forward_streaming(
    self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
@@ -379,9 +408,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
if request_id not in self.streaming_flow_cache:
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
self.hift_cache_dict[request_id] = dict(
-     mel = torch.zeros(1, 80, 0, device='cuda'),
-     source = torch.zeros(1, 1, 0, device='cuda'),
-     speech = torch.zeros(1, 0, device='cuda'),
+     mel=torch.zeros(1, 80, 0, device='cuda'),
+     source=torch.zeros(1, 1, 0, device='cuda'),
+     speech=torch.zeros(1, 0, device='cuda'),
)

current_request_cache = self.streaming_flow_cache[request_id]
@@ -389,7 +418,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')


chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
    token=generated_speech_tokens,
    spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
@@ -400,15 +428,12 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

self.streaming_flow_cache[request_id] = new_streaming_flow_cache


if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
    self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
    self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
], dim=4)



hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
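The branch above bounds the streaming attention cache: the prompt frames are always retained, and only the most recent 100 generated frames are kept behind them. A self-contained sketch of the same trim, assuming time sits on dim=4 of the cache as in the snippet (tensor sizes are made up):

import torch


def trim_att_cache(att_cache: torch.Tensor, prompt_len: int, keep_recent: int = 100) -> torch.Tensor:
    # Nothing to do until the cache outgrows prompt_len + keep_recent frames.
    if att_cache.shape[4] <= prompt_len + keep_recent:
        return att_cache
    return torch.cat([
        att_cache[:, :, :, :, :prompt_len],    # always keep the prompt portion
        att_cache[:, :, :, :, -keep_recent:],  # plus the most recent frames
    ], dim=4)


cache = torch.zeros(16, 2, 8, 4, 260)  # dummy 5-D cache with 260 steps on dim=4
print(trim_att_cache(cache, prompt_len=120).shape)  # torch.Size([16, 2, 8, 4, 220])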
@@ -422,9 +447,9 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

# update vocoder cache
self.hift_cache_dict[request_id] = dict(
-     mel = mel[..., -self.mel_cache_len:].clone().detach(),
-     source = source[:, :, -self.source_cache_len:].clone().detach(),
-     speech = speech[:, -self.source_cache_len:].clone().detach(),
+     mel=mel[..., -self.mel_cache_len:].clone().detach(),
+     source=source[:, :, -self.source_cache_len:].clone().detach(),
+     speech=speech[:, -self.source_cache_len:].clone().detach(),
)
if not last_chunk:
speech = speech[:, :-self.source_cache_len]
@@ -436,6 +461,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):

return speech


def collate_fn(batch):
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
for i, item in enumerate(batch):
@@ -447,6 +473,7 @@ def collate_fn(batch):

return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--enable-trt", action="store_true")
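collate_fn keeps each field as a plain Python list instead of stacking tensors, which is what the DataLoader in the __main__ block relies on. A runnable toy version with made-up item fields (the real dataset columns are not visible in this diff):

import torch
from torch.utils.data import DataLoader

items = [  # stand-in dataset items; field names are illustrative
    {"id": "utt0", "tokens": [1, 2, 3], "audio": torch.zeros(16000), "sr": 16000},
    {"id": "utt1", "tokens": [4, 5], "audio": torch.zeros(8000), "sr": 16000},
]


def collate_fn_sketch(batch):
    ids = [item["id"] for item in batch]
    tokens = [item["tokens"] for item in batch]
    audios = [item["audio"] for item in batch]
    sample_rates = [item["sr"] for item in batch]
    return ids, tokens, audios, sample_rates


loader = DataLoader(items, batch_size=2, shuffle=False, collate_fn=collate_fn_sketch, num_workers=0)
for ids, tokens, audios, sample_rates in loader:
    print(ids, [len(t) for t in tokens])  # ['utt0', 'utt1'] [3, 2]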
@@ -457,6 +484,7 @@ def get_args():
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
return parser.parse_args()


if __name__ == "__main__":
args = get_args()
model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
@@ -466,22 +494,17 @@ if __name__ == "__main__":

dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)


data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)


for epoch in range(args.warmup):
start_time = time.time()

for batch in data_loader:
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)


for id, wav in zip(ids, generated_wavs):
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)

end_time = time.time()
epoch_time = end_time - start_time
- print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
+ print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")