Commit: online feature
Repository: https://github.com/FunAudioLLM/CosyVoice.git

This commit adds optional online feature extraction: when --onnx_path is given, speaker
embeddings (campplus) and speech tokens can be computed on the fly via ONNX Runtime
instead of being read precomputed from the parquet data.
@@ -49,6 +49,7 @@ def get_args():
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
     parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+    parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--model_dir', required=True, help='save model dir')
     parser.add_argument('--tensorboard_dir',
@@ -96,6 +97,7 @@ def get_args():
 @record
 def main():
     args = get_args()
+    os.environ['onnx_path'] = args.onnx_path
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
     # gan train has some special initialization logic
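
One caveat worth noting: os.environ values must be strings, and --onnx_path is optional, so running without it leaves args.onnx_path as None and the new assignment raises a TypeError. A guard along these lines (an editor's sketch, not part of the commit) would keep the offline path working:

# --- sketch (not part of the commit) ---
# only export the variable when the flag is given, so offline training
# (no --onnx_path) does not crash on a None value
if args.onnx_path is not None:
    os.environ['onnx_path'] = args.onnx_path
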
@@ -104,12 +106,10 @@ def main():
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
     if gan is True:
         override_dict.pop('hift')
-    try:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
-    except Exception:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides=override_dict)
+    if args.qwen_pretrain_path is not None:
+        override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))
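
The rewritten block drops the try/except probe in favor of an explicit check: the qwen_pretrain_path override is only injected when the flag was given. For context, load_hyperpyyaml applies overrides by replacing matching keys before resolving !ref references; a minimal standalone sketch with hypothetical config keys:

# --- sketch (not part of the commit) ---
from hyperpyyaml import load_hyperpyyaml

# overrides replace keys before !ref resolution, so dependent values
# pick up the overridden number as well
yaml_string = """
lr: 0.001
warmup_lr: !ref <lr>
"""
configs = load_hyperpyyaml(yaml_string, overrides={'lr': 0.01})
print(configs['lr'], configs['warmup_lr'])  # 0.01 0.01
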
@@ -16,12 +16,13 @@ import random
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
import pyworld as pw
|
||||
|
||||
from cosyvoice.utils.onnx import embedding_extractor, online_feature
|
||||
|
||||
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
||||
|
||||
@@ -92,9 +93,9 @@ def filter(data,
             continue
         if len(sample['text_token']) > token_max_length:
             continue
-        if len(sample['speech_token']) == 0:
+        if online_feature is False and len(sample['speech_token']) == 0:
             continue
-        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+        if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
             continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
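
The two relaxed checks mean that when tokens are extracted online, samples without precomputed speech_token (or reject_speech_token) are no longer dropped. A toy illustration of the gating, with hypothetical sample dicts:

# --- sketch (not part of the commit) ---
def keep(sample, online_feature):
    # with online extraction, an empty offline token list is expected and kept;
    # with offline features, it still marks an unusable sample and is dropped
    if online_feature is False and len(sample['speech_token']) == 0:
        return False
    return True

sample = {'speech_token': []}
print(keep(sample, online_feature=True))   # True: tokens filled in later by the ONNX extractor
print(keep(sample, online_feature=False))  # False: offline pipeline has nothing to train on
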
@@ -155,7 +156,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 
 def compute_fbank(data,
                   feat_extractor,
                   token_mel_ratio=0,
+                  num_frames=-1,
                   mode='train'):
     """ Extract fbank
 
@@ -170,14 +171,11 @@ def compute_fbank(data,
         assert 'speech' in sample
         assert 'utt' in sample
         assert 'text_token' in sample
-        waveform = sample['speech']
-        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        if token_mel_ratio != 0:
-            # trim to align speech_token and speech_feat
-            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
-            feat = feat[:token_mel_ratio * token_len]
-            sample["speech_token"] = sample["speech_token"][:token_len]
-        sample['speech_feat'] = feat
+        # NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
+        if num_frames != -1:
+            index = int(np.ceil(sample['speech'].shape[1] / num_frames))
+            sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
+        sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
         yield sample
 
 
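The padding step rounds the waveform length up to a multiple of num_frames so that a whole number of 25 Hz tokens covers the signal before features are extracted. A worked sketch of the arithmetic (the num_frames value is an assumption for illustration):

# --- sketch (not part of the commit) ---
import numpy as np
import torch

# assume num_frames=640, i.e. one 25 Hz token per 640 samples of 16 kHz audio
num_frames = 640
speech = torch.randn(1, 16001)                       # one sample past a multiple
index = int(np.ceil(speech.shape[1] / num_frames))   # 26 token-sized blocks
pad = index * num_frames - speech.shape[1]           # 639 zeros appended
speech = torch.concat([speech, torch.zeros(1, pad)], dim=1)
assert speech.shape[1] == 26 * 640                   # feature frames now align with tokens
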
@@ -216,6 +214,10 @@ def parse_embedding(data, normalize, mode='train'):
         Iterable[{key, feat, label}]
     """
     for sample in data:
+        if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
+            speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
+            embedding = embedding_extractor.inference(speech_16k)
+            sample['spk_embedding'] = sample['utt_embedding'] = embedding
         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
         sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
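
With this fallback, parquet shards that lack precomputed utt_embedding/spk_embedding get one computed on the fly; the clip is first resampled to the 16 kHz that campplus expects. A small sketch of that resampling step (clip length and source rate are hypothetical):

# --- sketch (not part of the commit) ---
import torch
import torchaudio

speech = torch.randn(1, 24000)  # hypothetical 1 s clip at 24 kHz
speech_16k = torchaudio.transforms.Resample(orig_freq=24000, new_freq=16000)(speech)
print(speech_16k.shape)         # torch.Size([1, 16000])
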
@@ -256,13 +258,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
         Iterable[{key, feat, label}]
     """
     buf = []
+    yield_size = int(shuffle_size / 2)
     for sample in data:
         buf.append(sample)
         if len(buf) >= shuffle_size:
             random.shuffle(buf)
-            for x in buf:
+            for x in buf[:yield_size]:
                 yield x
-            buf = []
+            buf = buf[yield_size:]
     # The sample left over
     random.shuffle(buf)
     for x in buf:
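
This change turns the shuffle buffer into a sliding window: once full, it emits only half and carries the rest forward, so samples from adjacent buffer fills can still interleave. A self-contained toy version of the new logic:

# --- sketch (not part of the commit) ---
import random

def sliding_shuffle(data, shuffle_size=4):
    # emit half the buffer per flush; the carried-over half mixes with new arrivals
    buf, yield_size = [], int(shuffle_size / 2)
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf[:yield_size]
            buf = buf[yield_size:]
    random.shuffle(buf)  # the samples left over
    yield from buf

print(sorted(sliding_shuffle(range(10))))  # all 10 items survive, order is shuffled
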
@@ -420,10 +423,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
                                    padding_value=0)
         batch["pitch_feat"] = pitch_feat
         batch["pitch_feat_len"] = pitch_feat_len
-    else:
-        # only gan train needs speech, delete it to save memory
-        del batch["speech"]
-        del batch["speech_len"]
     if dpo is True:
         reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
         reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
 
@@ -19,6 +19,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 from omegaconf import DictConfig
 from cosyvoice.utils.mask import make_pad_mask
+from cosyvoice.utils.onnx import SpeechTokenExtractor
 
 
 class MaskedDiffWithXvec(torch.nn.Module):
@@ -28,6 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
+from cosyvoice.utils.onnx import SpeechTokenExtractor
 
 
 class TransformerLM(torch.nn.Module):
cosyvoice/utils/onnx.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+import onnxruntime
+import torch, random
+from torch import nn
+import os
+import whisper
+import numpy as np
+import torchaudio.compliance.kaldi as kaldi
+import torch.nn.functional as F
+
+
+class SpeechTokenExtractor():
+    def __init__(self, model_path):
+        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
+                                                                     sess_options=option,
+                                                                     providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
+
+    def inference(self, feat, feat_lengths, device):
+        ort_out = self.speech_tokenizer_session.run(None,
+                                                    {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                                     self.speech_tokenizer_session.get_inputs()[1].name: feat_lengths.detach().cpu().numpy()})
+        speech_token, speech_token_embedding = ort_out[0], ort_out[1]
+        return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
+
+
+class EmbeddingExtractor():
+    def __init__(self, model_path):
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.max_len = 10 * 16000
+        self.campplus_session = onnxruntime.InferenceSession(model_path,
+                                                             sess_options=option,
+                                                             providers=["CPUExecutionProvider"])
+
+    def inference(self, speech):
+        if speech.shape[1] > self.max_len:
+            start_index = random.randint(0, speech.shape[1] - self.max_len)
+            speech = speech[:, start_index: start_index + self.max_len]
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        return torch.tensor(embedding).to(speech.device)
+
+
+# singleton mode, only initialized once
+onnx_path = os.environ.get('onnx_path')
+if onnx_path is not None:
+    embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
+else:
+    embedding_extractor, online_feature = None, False
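
A hedged usage sketch of the new module. The tokenizer filename, feature shapes, and the ~192-dim campplus output are assumptions, the env var must be set before the first import (the singleton is created at import time), and SpeechTokenExtractor needs a CUDA-capable onnxruntime:

# --- sketch (not part of the commit) ---
import os
os.environ['onnx_path'] = '/models/cosyvoice_onnx'  # assumed directory; set before the import below

import torch
from cosyvoice.utils.onnx import SpeechTokenExtractor, embedding_extractor, online_feature

assert online_feature is True                       # campplus.onnx was found under onnx_path

# speaker embedding from a hypothetical 3 s, 16 kHz mono clip
speech_16k = torch.randn(1, 3 * 16000)
spk_embedding = embedding_extractor.inference(speech_16k)

# speech tokens from whisper-style fbank features; note the returned lengths are
# feat_lengths / 2, i.e. the tokenizer halves the feature rate down to 25 Hz tokens
extractor = SpeechTokenExtractor(os.path.join(os.environ['onnx_path'], 'speech_tokenizer_v2.onnx'))
feat = torch.randn(1, 300, 128)                     # assumed [batch, frames, mels] layout
feat_lengths = torch.tensor([300], dtype=torch.int32)
tokens, token_lengths = extractor.inference(feat, feat_lengths, device='cpu')
print(tokens.shape, token_lengths)                  # token_lengths == [150]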