Mirror of https://github.com/FunAudioLLM/CosyVoice.git
add flow decoder cache
@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.dataset.dataset import Dataset
@@ -33,6 +33,7 @@ def get_args():
     parser.add_argument('--prompt_data', required=True, help='prompt data file')
     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
     parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--llm_model', required=True, help='llm model file')
     parser.add_argument('--flow_model', required=True, help='flow model file')
     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
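The new --qwen_pretrain_path flag is optional and only matters for CosyVoice2-style configs (see the model-loading change in the next hunk). As a rough usage sketch, assuming the patched file is the repo's batch inference script and with every path a placeholder:

    python cosyvoice/bin/inference.py --config <config.yaml> --gpu 0 \
        --qwen_pretrain_path <qwen_pretrain_dir> \
        --llm_model <llm.pt> --flow_model <flow.pt> --hifigan_model <hift.pt> \
        --prompt_data <prompt_data_file> --prompt_utt2data <prompt_utt2data_file> \
        --tts_text <tts_text_file> --result_dir <out_dir>

For a CosyVoice (v1) config the flag can simply be omitted, since it is declared with required=False.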
@@ -59,10 +60,18 @@ def main():
     # Init cosyvoice models from configs
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
     device = torch.device('cuda' if use_cuda else 'cpu')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f)
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False)
+    except Exception:
+        try:
+            with open(args.config, 'r') as f:
+                configs = load_hyperpyyaml(f)
+            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False)
+        except Exception:
+            raise TypeError('no valid model_type!')
 
-    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
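Pulled out of main(), the loading strategy above reads as a small helper: try to interpret the config as CosyVoice2 (whose YAML expects a qwen_pretrain_path override pointing at the pretrained Qwen weights), and fall back to the original CosyVoiceModel otherwise. This is a minimal sketch mirroring the patch; the helper name load_cosyvoice_any is made up here, and the broad except clauses are kept as in the diff.

    from hyperpyyaml import load_hyperpyyaml
    from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model

    def load_cosyvoice_any(config_path, qwen_pretrain_path=None):
        # Attempt a CosyVoice2 config first: its YAML is resolved with the
        # qwen_pretrain_path override.
        try:
            with open(config_path, 'r') as f:
                configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': qwen_pretrain_path})
            return CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16=False), configs
        except Exception:
            pass
        # Fall back to a CosyVoice (v1) config, which takes no overrides.
        try:
            with open(config_path, 'r') as f:
                configs = load_hyperpyyaml(f)
            return CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16=False), configs
        except Exception:
            raise TypeError('no valid model_type!')

Whichever model comes back still needs model.load(llm_model, flow_model, hifigan_model) before inference, exactly as in the unchanged line of the hunk above.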
@@ -104,7 +113,7 @@ def main():
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=configs['sample_rate'], backend='soundfile')
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()
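One note on the save change: the hard-coded 22050 Hz only matches CosyVoice (v1), while CosyVoice2 configs use a different sample_rate (24000 Hz in the released configs), so the value is now read from the loaded YAML; backend='soundfile' pins torchaudio to a specific I/O backend (the backend argument exists in recent torchaudio releases) instead of relying on its default dispatch. A defensive variant, purely a sketch and not part of the patch, could keep the old default when a config lacks the key:

    # Hypothetical guard: fall back to 22050 Hz if the YAML has no top-level sample_rate.
    torchaudio.save(tts_fn, tts_speeches,
                    sample_rate=configs.get('sample_rate', 22050),
                    backend='soundfile')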