fix lint
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
@@ -76,6 +75,7 @@ class UserData:
return self._first_chunk_time - self._start_time
return None


def callback(user_data, result, error):
if user_data._first_chunk_time is None and not error:
user_data._first_chunk_time = time.time() # Record time of first successful chunk
@@ -206,8 +206,11 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts", "cosyvoice2"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2"],
help="triton model_repo module name to request",
)

parser.add_argument(
@@ -273,6 +276,7 @@ def load_audio(wav_path, target_sample_rate=16000):
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate


def prepare_request_input_output(
protocol_client, # Can be grpcclient_aio or grpcclient_sync
waveform,
@@ -329,6 +333,7 @@ def prepare_request_input_output(

return inputs, outputs


def run_sync_streaming_inference(
sync_triton_client: tritonclient.grpc.InferenceServerClient,
model_name: str,
@@ -493,7 +498,6 @@ async def send_streaming(
else:
print(f"{name}: Item {i} failed.")


except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
except Exception as e:
@@ -501,7 +505,6 @@ async def send_streaming(
import traceback
traceback.print_exc()


finally: # Ensure client is closed
if sync_triton_client:
try:
@@ -510,10 +513,10 @@ async def send_streaming(
except Exception as e:
print(f"{name}: Error closing sync client: {e}")


print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
return total_duration, latency_data

async def send(
manifest_item_list: list,
name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):

return result


async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
@@ -713,7 +717,6 @@ async def main():
else:
print("Warning: A task returned None, possibly due to an error.")


if total_duration == 0:
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
rtf = float('inf')
@@ -29,6 +29,7 @@ import json
import numpy as np
import argparse


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
type=str,
default="spark_tts",
choices=[
"f5_tts", "spark_tts", "cosyvoice2"
],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
"f5_tts",
"spark_tts",
"cosyvoice2"],
help="triton model_repo module name to request",
)

parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
)
return parser.parse_args()


def prepare_request(
waveform,
reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
1,
padding_duration
* sample_rate
* ((int(duration) // padding_duration) + 1),
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
),
dtype=np.float32,
)
@@ -139,6 +142,7 @@ def prepare_request(

return data


if __name__ == "__main__":
args = get_args()
server_url = args.server_url
@@ -35,6 +35,7 @@ import s3tokenizer

ORIGINAL_VOCAB_SIZE = 151663


class TritonPythonModel:
"""Triton Python model for audio tokenization.


@@ -42,6 +42,7 @@ import onnxruntime

from matcha.utils.audio import mel_spectrogram


class TritonPythonModel:
"""Triton Python model for Spark TTS.
@@ -187,7 +188,12 @@ class TritonPythonModel:

return prompt_speech_tokens

def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
def forward_token2wav(
self,
prompt_speech_tokens: torch.Tensor,
prompt_speech_feat: torch.Tensor,
prompt_spk_embedding: torch.Tensor,
target_speech_tokens: torch.Tensor) -> torch.Tensor:
"""Forward pass through the vocoder component.

Args:
@@ -240,9 +246,20 @@ class TritonPythonModel:
embedding = torch.tensor([embedding]).to(self.device).half()
return embedding


def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
@@ -267,7 +284,6 @@ class TritonPythonModel:
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)


wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -311,7 +327,7 @@ class TritonPythonModel:
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info(f"send tritonserver_response_complete_final to end")
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)

ORIGINAL_VOCAB_SIZE = 151663


class CosyVoice2:

def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
trt_concurrent,
self.fp16)


class CosyVoice2Model:

def __init__(self,
@@ -109,6 +111,7 @@ class CosyVoice2Model:
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}


class TritonPythonModel:
"""Triton Python model for vocoder.

@@ -137,7 +140,6 @@ class TritonPythonModel:

logger.info("Token2Wav initialized successfully")


def execute(self, requests):
"""Execute inference on the batched requests.

@@ -191,7 +193,3 @@ class TritonPythonModel:

responses.append(inference_response)

return responses
@@ -35,8 +35,7 @@ def parse_arguments():
type=str,
default='auto',
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=
"The data type for the model weights and activations if not quantized. "
help="The data type for the model weights and activations if not quantized. "
"If 'auto', the data type is automatically inferred from the source model; "
"however, if the source dtype is float32, it is converted to float16.")
parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
'--disable_weight_only_quant_plugin',
default=False,
action="store_true",
help=
'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_gptq'],
help=
'Define the precision for the weights when using weight-only quantization.'
help='Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--calib_dataset',
type=str,
default='ccdv/cnn_dailymail',
help=
"The huggingface dataset name or the local directory of the dataset for calibration."
help="The huggingface dataset name or the local directory of the dataset for calibration."
)
parser.add_argument(
"--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
'--per_channel',
action="store_true",
default=False,
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
help='By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
action="store_true",
default=False,
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
help='By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
)
parser.add_argument(
'--per_group',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale weights in the int4 range. '
help='By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The flag is built for GPTQ/AWQ quantization.')
@@ -121,16 +113,14 @@ def parse_arguments():
'--use_parallel_embedding',
action="store_true",
default=False,
help=
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help=
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
'To shard it along hidden dimension, set embedding_sharding_dim=1'
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
)
@@ -147,15 +137,13 @@ def parse_arguments():
'--moe_tp_size',
type=int,
default=-1,
help=
'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
)
parser.add_argument(
'--moe_ep_size',
type=int,
default=-1,
help=
'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
)
args = parser.parse_args()
return args
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
trust_remote_code=True)
quant_config, override_fields = update_quant_config_from_hf(
quant_config, hf_config, override_fields)
except:
except BaseException:
logger.warning("AutoConfig cannot load the huggingface config.")

if args.smoothquant is not None or args.int8_kv_cache:
@@ -1,4 +1,4 @@
#! /usr/bin/env python3
# /usr/bin/env python3
from argparse import ArgumentParser
from string import Template

@@ -59,8 +59,7 @@ if __name__ == "__main__":
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
help=
"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
)
parser.add_argument("--in_place",
"-i",

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
parser.add_argument('--top_k', type=int, default=50)
parser.add_argument('--top_p', type=float, default=0.95)


return parser.parse_args(args=args)