This commit is contained in:
lyuxiang.lx
2024-12-31 17:08:11 +08:00
parent 2745d47e92
commit 77d8cf13a3
11 changed files with 163 additions and 158 deletions

View File

@@ -27,7 +27,7 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -56,14 +56,20 @@ def main():
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
try:
model = CosyVoice(args.model_dir)
except:
try:
model = CosyVoice2(args.model_dir)
except:
raise TypeError('no valid model_type!')
# 1. export flow decoder estimator
estimator = cosyvoice.model.flow.decoder.estimator
estimator = model.model.flow.decoder.estimator
device = cosyvoice.model.device
batch_size, seq_len = 1, 256
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
device = model.model.device
batch_size, seq_len = 2, 256
out_channels = model.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
torch.onnx.export(
estimator,
@@ -75,13 +81,11 @@ def main():
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['estimator_out'],
dynamic_axes={
'x': {0: 'batch_size', 2: 'seq_len'},
'mask': {0: 'batch_size', 2: 'seq_len'},
'mu': {0: 'batch_size', 2: 'seq_len'},
'cond': {0: 'batch_size', 2: 'seq_len'},
't': {0: 'batch_size'},
'spks': {0: 'batch_size'},
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
'x': {2: 'seq_len'},
'mask': {2: 'seq_len'},
'mu': {2: 'seq_len'},
'cond': {2: 'seq_len'},
'estimator_out': {2: 'seq_len'},
}
)
@@ -94,7 +98,7 @@ def main():
sess_options=option, providers=providers)
for _ in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
output_pytorch = estimator(x, mask, mu, t, spks, cond)
ort_inputs = {
'x': x.cpu().numpy(),