mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
Merge branch 'dev/lyuxiang.lx' of http://gitlab.alibaba-inc.com/NLS/CosyVoice into dev/lyuxiang.lx
This commit is contained in:
@@ -1,126 +0,0 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torchaudio
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from tqdm import tqdm
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||
from cosyvoice.dataset.dataset import Dataset
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='inference with your model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
||||
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
||||
parser.add_argument('--tts_text', required=True, help='tts input file')
|
||||
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
|
||||
parser.add_argument('--llm_model', required=True, help='llm model file')
|
||||
parser.add_argument('--flow_model', required=True, help='flow model file')
|
||||
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--mode',
|
||||
default='sft',
|
||||
choices=['sft', 'zero_shot'],
|
||||
help='inference mode')
|
||||
parser.add_argument('--result_dir', required=True, help='asr result file')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||
|
||||
# Init cosyvoice models from configs
|
||||
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
try:
|
||||
with open(args.config, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
|
||||
model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
|
||||
except Exception:
|
||||
try:
|
||||
with open(args.config, 'r') as f:
|
||||
configs = load_hyperpyyaml(f)
|
||||
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
||||
|
||||
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
||||
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
||||
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
||||
|
||||
sample_rate = configs['sample_rate']
|
||||
del configs
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
fn = os.path.join(args.result_dir, 'wav.scp')
|
||||
f = open(fn, 'w')
|
||||
with torch.no_grad():
|
||||
for _, batch in tqdm(enumerate(test_data_loader)):
|
||||
utts = batch["utts"]
|
||||
assert len(utts) == 1, "inference mode only support batchsize 1"
|
||||
text_token = batch["text_token"].to(device)
|
||||
text_token_len = batch["text_token_len"].to(device)
|
||||
tts_index = batch["tts_index"]
|
||||
tts_text_token = batch["tts_text_token"].to(device)
|
||||
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
||||
speech_token = batch["speech_token"].to(device)
|
||||
speech_token_len = batch["speech_token_len"].to(device)
|
||||
speech_feat = batch["speech_feat"].to(device)
|
||||
speech_feat_len = batch["speech_feat_len"].to(device)
|
||||
utt_embedding = batch["utt_embedding"].to(device)
|
||||
spk_embedding = batch["spk_embedding"].to(device)
|
||||
if args.mode == 'sft':
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
||||
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
||||
else:
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
||||
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
||||
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
||||
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
||||
tts_speeches = []
|
||||
for model_output in model.tts(**model_input):
|
||||
tts_speeches.append(model_output['tts_speech'])
|
||||
tts_speeches = torch.concat(tts_speeches, dim=1)
|
||||
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
||||
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
||||
torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
|
||||
f.write('{} {}\n'.format(tts_key, tts_fn))
|
||||
f.flush()
|
||||
f.close()
|
||||
logging.info('Result wav.scp saved in {}'.format(fn))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
|
||||
main()
|
||||
@@ -38,9 +38,6 @@ class CosyVoiceModel:
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
||||
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
||||
self.token_overlap_len = 20
|
||||
@@ -249,9 +246,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
# NOTE must matching training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# hift cache
|
||||
@@ -398,9 +392,6 @@ class CosyVoice3Model(CosyVoice2Model):
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
# NOTE must matching training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# rtf and decoding related
|
||||
|
||||
@@ -242,6 +242,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
||||
for sample in data:
|
||||
assert 'text' in sample
|
||||
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
||||
if 'instruct' in sample:
|
||||
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
|
||||
else:
|
||||
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -390,6 +394,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
||||
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
|
||||
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
|
||||
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
|
||||
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
||||
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||
batch = {
|
||||
@@ -403,6 +410,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
"text": text,
|
||||
"text_token": text_token,
|
||||
"text_token_len": text_token_len,
|
||||
"instruct_token": instruct_token,
|
||||
"instruct_token_len": instruct_token_len,
|
||||
"utt_embedding": utt_embedding,
|
||||
"spk_embedding": spk_embedding,
|
||||
}
|
||||
|
||||
@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
|
||||
sol = []
|
||||
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
|
||||
@@ -674,6 +674,9 @@ class CosyVoice3LM(Qwen2LM):
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
# NOTE should append instruct_token to sequence, not implemented yet
|
||||
instruct_token = batch['instruct_token'].to(device)
|
||||
instruct_token_len = batch['instruct_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
@@ -40,6 +40,11 @@ def main():
|
||||
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
||||
for k, v in spk2utt.items():
|
||||
f.write('{} {}\n'.format(k, ' '.join(v)))
|
||||
if args.instruct is True:
|
||||
with open('{}/instruct'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2text.items():
|
||||
# NOTE in CosyVoice3, we add instruct in sequence
|
||||
f.write('{} You are a helpful assistant.<|endofprompt|>\n'.format(k, v))
|
||||
return
|
||||
|
||||
|
||||
@@ -49,7 +54,9 @@ if __name__ == "__main__":
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
parser.add_argument('--instruct',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='create instruct file or not')
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
|
||||
@@ -20,7 +20,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x
|
||||
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
|
||||
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct
|
||||
done
|
||||
fi
|
||||
|
||||
@@ -46,6 +46,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--instruct \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
done
|
||||
|
||||
@@ -37,6 +37,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
|
||||
speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
|
||||
if args.dpo:
|
||||
reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
|
||||
if args.instruct:
|
||||
instruct_list = [utt2instruct[utt] for utt in utt_list]
|
||||
|
||||
# 保存到parquet,utt2parquet_file,spk2parquet_file
|
||||
df = pd.DataFrame()
|
||||
@@ -50,6 +52,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
|
||||
df['speech_token'] = speech_token_list
|
||||
if args.dpo:
|
||||
df['reject_speech_token'] = reject_speech_token_list
|
||||
if args.instruct:
|
||||
df['instruct'] = instruct_list
|
||||
df.to_parquet(parquet_file)
|
||||
with open(utt2parquet_file, 'w') as f:
|
||||
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
|
||||
@@ -68,6 +72,10 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=1,
|
||||
help='num processes for make parquets')
|
||||
parser.add_argument('--instruct',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='has instruct file or not')
|
||||
parser.add_argument('--src_dir',
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
@@ -91,6 +99,11 @@ if __name__ == "__main__":
|
||||
for l in f:
|
||||
l = l.replace('\n', '').split()
|
||||
utt2spk[l[0]] = l[1]
|
||||
if args.instruct is True:
|
||||
with open('{}/instruct'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.replace('\n', '').split()
|
||||
utt2instruct[l[0]] = ' '.join(l[1:])
|
||||
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
|
||||
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
|
||||
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
|
||||
|
||||
Reference in New Issue
Block a user