mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
fix lint
This commit is contained in:
@@ -19,12 +19,13 @@ import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
import torch
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||
parser.add_argument('--model_dir',
|
||||
@@ -35,6 +36,7 @@ def get_args():
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
@@ -67,5 +69,6 @@ def main():
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
@@ -20,13 +20,13 @@ import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
import onnxruntime
|
||||
import random
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ def get_args():
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
@@ -89,7 +90,8 @@ def main():
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers)
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
||||
@@ -105,5 +107,6 @@ def main():
|
||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -18,16 +18,15 @@ import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torchaudio
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from tqdm import tqdm
|
||||
from cosyvoice.cli.model import CosyVoiceModel
|
||||
|
||||
from cosyvoice.dataset.dataset import Dataset
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='inference with your model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
@@ -66,7 +65,8 @@ def main():
|
||||
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
||||
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
||||
|
||||
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
||||
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
||||
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
||||
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
||||
|
||||
del configs
|
||||
@@ -74,13 +74,11 @@ def main():
|
||||
fn = os.path.join(args.result_dir, 'wav.scp')
|
||||
f = open(fn, 'w')
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in tqdm(enumerate(test_data_loader)):
|
||||
for _, batch in tqdm(enumerate(test_data_loader)):
|
||||
utts = batch["utts"]
|
||||
assert len(utts) == 1, "inference mode only support batchsize 1"
|
||||
text = batch["text"]
|
||||
text_token = batch["text_token"].to(device)
|
||||
text_token_len = batch["text_token_len"].to(device)
|
||||
tts_text = batch["tts_text"]
|
||||
tts_index = batch["tts_index"]
|
||||
tts_text_token = batch["tts_text_token"].to(device)
|
||||
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
||||
|
||||
@@ -132,5 +132,6 @@ def main():
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
|
||||
dist.destroy_process_group(group_join)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user