mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-05 01:49:25 +08:00)

Commit: add instruct
cosyvoice/dataset/processor.py

@@ -242,6 +242,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+        if 'instruct' in sample:
+            sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
+        else:
+            sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
         yield sample
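After this change, every sample leaving tokenize carries an instruct_token field even when no instruct text exists, so downstream stages can rely on the key unconditionally. A minimal sketch of the behavior (not from the commit; fake_encode stands in for the tiktoken-style tokenizer.encode):

    # Hypothetical stand-in for tokenizer.encode(text, allowed_special=...)
    def fake_encode(text, allowed_special='all'):
        return [ord(c) for c in text]

    samples = [{'text': 'hello'},                               # no instruct
               {'text': 'hello', 'instruct': 'Speak slowly.'}]  # with instruct
    for s in samples:
        s['instruct_token'] = fake_encode(s.get('instruct', ''))
    # Both samples now have 'instruct_token'; the first holds the encoding
    # of the empty string, so the padding stage never hits a KeyError.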
@@ -390,6 +394,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+        instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
+        instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
+        instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
@@ -403,6 +410,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             "text": text,
             "text_token": text_token,
             "text_token_len": text_token_len,
+            "instruct_token": instruct_token,
+            "instruct_token_len": instruct_token_len,
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
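The new instruct_token entries are collated exactly like text_token: per-sample lengths are recorded first, then the sequences are zero-padded to the batch maximum. A toy illustration (not from the commit) of what pad_sequence does here:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    instruct_token = [torch.tensor([5, 6, 7]), torch.tensor([5])]
    instruct_token_len = torch.tensor([t.size(0) for t in instruct_token], dtype=torch.int32)
    instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
    # instruct_token     -> tensor([[5, 6, 7],
    #                               [5, 0, 0]])
    # instruct_token_len -> tensor([3, 1], dtype=torch.int32)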
cosyvoice/llm/llm.py

@@ -674,6 +674,9 @@ class CosyVoice3LM(Qwen2LM):
         text_token_len = batch['text_token_len'].to(device)
         speech_token = batch['speech_token'].to(device)
         speech_token_len = batch['speech_token_len'].to(device)
+        # NOTE should append instruct_token to sequence, not implemented yet
+        instruct_token = batch['instruct_token'].to(device)
+        instruct_token_len = batch['instruct_token_len'].to(device)

         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
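The NOTE marks the open follow-up: instruct_token is moved to the device but not yet woven into the LM input. One plausible continuation (a sketch under assumptions, not the repo's implementation) would embed the instruct tokens with the same Qwen2 embedding table and prepend them to the text embeddings:

    # Hypothetical sketch only; the commit explicitly leaves this unimplemented.
    instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
    # Prepend instruct embeddings and extend the lengths. A real implementation
    # would unpad per utterance first, since naive concatenation of padded
    # batches leaves instruct padding in the middle of each sequence.
    lm_input = torch.cat([instruct_token_emb, text_token_emb], dim=1)
    lm_input_len = instruct_token_len + text_token_len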
local/prepare_data.py

@@ -40,6 +40,11 @@ def main():
     with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
         for k, v in spk2utt.items():
             f.write('{} {}\n'.format(k, ' '.join(v)))
+    if args.instruct is True:
+        with open('{}/instruct'.format(args.des_dir), 'w') as f:
+            for k, v in utt2text.items():
+                # NOTE in CosyVoice3, we add instruct in sequence
+                f.write('{} You are a helpful assistant.<|endofprompt|>\n'.format(k, v))
     return
@@ -49,7 +54,9 @@ if __name__ == "__main__":
                         type=str)
     parser.add_argument('--des_dir',
                         type=str)
-    parser.add_argument('--ref_model',
-                        type=str)
+    parser.add_argument('--instruct',
+                        action='store_true',
+                        default=False,
+                        help='create instruct file or not')
     args = parser.parse_args()
     main()
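With --instruct set, prepare_data.py writes one fixed instruction line per utterance. Note that v is currently unused: the format string has a single placeholder, which consumes only k, and str.format silently ignores the extra argument. The resulting instruct file pairs each utterance id with the same prompt (LibriTTS-style ids shown for illustration):

    1089_134686_000001_000001 You are a helpful assistant.<|endofprompt|>
    1089_134686_000002_000000 You are a helpful assistant.<|endofprompt|>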
run.sh

@@ -20,7 +20,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
     mkdir -p data/$x
-    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct
   done
 fi
@@ -46,6 +46,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     mkdir -p data/$x/parquet
     tools/make_parquet_list.py --num_utts_per_parquet 1000 \
       --num_processes 10 \
+      --instruct \
       --src_dir data/$x \
       --des_dir data/$x/parquet
   done
tools/make_parquet_list.py

@@ -37,6 +37,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
     if args.dpo:
         reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
+    if args.instruct:
+        instruct_list = [utt2instruct[utt] for utt in utt_list]

     # save to parquet, utt2parquet_file, spk2parquet_file
     df = pd.DataFrame()
@@ -50,6 +52,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     df['speech_token'] = speech_token_list
     if args.dpo:
         df['reject_speech_token'] = reject_speech_token_list
+    if args.instruct:
+        df['instruct'] = instruct_list
     df.to_parquet(parquet_file)
     with open(utt2parquet_file, 'w') as f:
         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
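Each parquet shard now carries an optional instruct column alongside speech_token, so the training dataset can surface the instruct text to tokenize. A quick sanity check (sketch; the shard path and name are hypothetical):

    import pandas as pd

    df = pd.read_parquet('data/train-clean-100/parquet/parquet_000000000.tar')
    print(df.columns)              # ..., 'speech_token', 'instruct'
    print(df['instruct'].iloc[0])  # 'You are a helpful assistant.<|endofprompt|>'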
@@ -68,6 +72,10 @@ if __name__ == "__main__":
                         type=int,
                         default=1,
                         help='num processes for make parquets')
+    parser.add_argument('--instruct',
+                        action='store_true',
+                        default=False,
+                        help='has instruct file or not')
     parser.add_argument('--src_dir',
                         type=str)
     parser.add_argument('--des_dir',
@@ -91,6 +99,11 @@ if __name__ == "__main__":
         for l in f:
             l = l.replace('\n', '').split()
             utt2spk[l[0]] = l[1]
+    if args.instruct is True:
+        with open('{}/instruct'.format(args.src_dir)) as f:
+            for l in f:
+                l = l.replace('\n', '').split()
+                utt2instruct[l[0]] = ' '.join(l[1:])
     utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
     spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
     utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
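The reader splits each instruct line on whitespace and rejoins everything after the utterance id, so the prompt text round-trips intact (up to collapsed internal whitespace). For example:

    line = '1089_134686_000001_000001 You are a helpful assistant.<|endofprompt|>\n'
    parts = line.replace('\n', '').split()
    utt, instruct = parts[0], ' '.join(parts[1:])
    # utt      == '1089_134686_000001_000001'
    # instruct == 'You are a helpful assistant.<|endofprompt|>'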