mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update
This commit is contained in:
@@ -27,7 +27,7 @@ from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
|
||||
|
||||
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
||||
@@ -56,14 +56,20 @@ def main():
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
||||
try:
|
||||
model = CosyVoice(args.model_dir)
|
||||
except:
|
||||
try:
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
# 1. export flow decoder estimator
|
||||
estimator = cosyvoice.model.flow.decoder.estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
|
||||
device = cosyvoice.model.device
|
||||
batch_size, seq_len = 1, 256
|
||||
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
||||
device = model.model.device
|
||||
batch_size, seq_len = 2, 256
|
||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
@@ -75,13 +81,11 @@ def main():
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||
output_names=['estimator_out'],
|
||||
dynamic_axes={
|
||||
'x': {0: 'batch_size', 2: 'seq_len'},
|
||||
'mask': {0: 'batch_size', 2: 'seq_len'},
|
||||
'mu': {0: 'batch_size', 2: 'seq_len'},
|
||||
'cond': {0: 'batch_size', 2: 'seq_len'},
|
||||
't': {0: 'batch_size'},
|
||||
'spks': {0: 'batch_size'},
|
||||
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
|
||||
'x': {2: 'seq_len'},
|
||||
'mask': {2: 'seq_len'},
|
||||
'mu': {2: 'seq_len'},
|
||||
'cond': {2: 'seq_len'},
|
||||
'estimator_out': {2: 'seq_len'},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -94,7 +98,7 @@ def main():
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
|
||||
Reference in New Issue
Block a user