mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update dpo
This commit is contained in:
@@ -49,5 +49,7 @@ if __name__ == "__main__":
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
|
||||
49
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
49
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch, torchaudio
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main():
|
||||
cosyvoice = CosyVoice2(args.ref_model)
|
||||
|
||||
utt2wav, utt2text = {}, {}
|
||||
with open('{}/wav.scp'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2wav[l[0]] = l[1]
|
||||
with open('{}/text'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2text[l[0]] = ' '.join(l[1:])
|
||||
|
||||
os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
|
||||
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
||||
for utt, wav in tqdm(utt2wav.items()):
|
||||
prompt_speech_16k = load_wav(wav, 16000)
|
||||
if prompt_speech_16k.shape[1] >= 30 * 16000:
|
||||
continue
|
||||
speech_list = []
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
|
||||
speech_list.append(j['tts_speech'])
|
||||
negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
|
||||
torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
|
||||
f.write('{} {}\n'.format(utt, negative_wav))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_dir',
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
@@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
# inference
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
||||
for mode in sft zero_shot; do
|
||||
python cosyvoice/bin/inference.py --mode $mode \
|
||||
--gpu 0 \
|
||||
--config conf/cosyvoice.yaml \
|
||||
--prompt_data data/test-clean/parquet/data.list \
|
||||
--prompt_utt2data data/test-clean/parquet/utt2data.list \
|
||||
--tts_text `pwd`/tts_text.json \
|
||||
--llm_model $pretrained_model_dir/llm.pt \
|
||||
--flow_model $pretrained_model_dir/flow.pt \
|
||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
||||
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
Reference in New Issue
Block a user