add flow decoder cache

This commit is contained in:
lyuxiang.lx
2025-01-23 16:48:13 +08:00
parent 190840b8dc
commit 1c062ab381
21 changed files with 1601 additions and 214 deletions

View File

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset
@@ -33,6 +33,7 @@ def get_args():
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
@@ -59,10 +60,18 @@ def main():
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False)
except Exception:
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False)
except Exception:
raise TypeError('no valid model_type!')
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
@@ -104,7 +113,7 @@ def main():
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
torchaudio.save(tts_fn, tts_speeches, sample_rate=configs['sample_rate'], backend='soundfile')
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()