Merge branch 'dev/lyuxiang.lx' of http://gitlab.alibaba-inc.com/NLS/CosyVoice into dev/lyuxiang.lx

This commit is contained in:
hengwu.zty
2025-12-12 18:39:18 +08:00
8 changed files with 43 additions and 144 deletions

View File

@@ -1,126 +0,0 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
parser.add_argument('--result_dir', required=True, help='asr result file')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
except Exception:
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
except Exception:
raise TypeError('no valid model_type!')
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
sample_rate = configs['sample_rate']
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
tts_speeches = []
for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
main()

View File

@@ -38,9 +38,6 @@ class CosyVoiceModel:
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
@@ -249,9 +246,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# hift cache
@@ -398,9 +392,6 @@ class CosyVoice3Model(CosyVoice2Model):
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# rtf and decoding related

View File

@@ -242,6 +242,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
else:
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
yield sample
@@ -390,6 +394,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
@@ -403,6 +410,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"instruct_token": instruct_token,
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}

View File

@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x

View File

@@ -674,6 +674,9 @@ class CosyVoice3LM(Qwen2LM):
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE should append instruct_token to sequence, not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)

View File

@@ -40,6 +40,11 @@ def main():
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
for k, v in spk2utt.items():
f.write('{} {}\n'.format(k, ' '.join(v)))
if args.instruct is True:
with open('{}/instruct'.format(args.des_dir), 'w') as f:
for k, v in utt2text.items():
# NOTE in CosyVoice3, we add instruct in sequence
f.write('{} You are a helpful assistant.<|endofprompt|>\n'.format(k, v))
return
@@ -49,7 +54,9 @@ if __name__ == "__main__":
type=str)
parser.add_argument('--des_dir',
type=str)
parser.add_argument('--ref_model',
type=str)
parser.add_argument('--instruct',
action='store_true',
default=False,
help='create instruct file or not')
args = parser.parse_args()
main()

View File

@@ -20,7 +20,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct
done
fi
@@ -46,6 +46,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--instruct \
--src_dir data/$x \
--des_dir data/$x/parquet
done

View File

@@ -37,6 +37,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
if args.dpo:
reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
if args.instruct:
instruct_list = [utt2instruct[utt] for utt in utt_list]
# 保存到parquet,utt2parquet_file,spk2parquet_file
df = pd.DataFrame()
@@ -50,6 +52,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
df['speech_token'] = speech_token_list
if args.dpo:
df['reject_speech_token'] = reject_speech_token_list
if args.instruct:
df['instruct'] = instruct_list
df.to_parquet(parquet_file)
with open(utt2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -68,6 +72,10 @@ if __name__ == "__main__":
type=int,
default=1,
help='num processes for make parquets')
parser.add_argument('--instruct',
action='store_true',
default=False,
help='has instruct file or not')
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
@@ -91,6 +99,11 @@ if __name__ == "__main__":
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
if args.instruct is True:
with open('{}/instruct'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2instruct[l[0]] = ' '.join(l[1:])
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))