lyuxiang.lx
2026-01-29 06:13:36 +00:00
parent 66b80dbccb
commit f26cde56df
7 changed files with 90 additions and 73 deletions

View File

@@ -256,6 +256,10 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -353,6 +357,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                                    stream=stream,
                                                    finalize=False)
                     token_offset += this_token_hop_len
+                    self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
                     yield {'tts_speech': this_tts_speech.cpu()}
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                     break
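Not part of the commit, just an illustration: a minimal standalone sketch of the chunk schedule the two added lines produce, assuming the defaults above (token_hop_len=25, stream_scale_factor=2, token_max_hop_len=100) and ignoring pre_lookahead_len. The first chunk stays small for low first-packet latency, then each chunk doubles until it is capped.

# Illustration only (not repository code): how the streaming hop length grows per chunk.
token_hop_len, token_max_hop_len, stream_scale_factor = 25, 100, 2
token_offset, total_tokens, chunks = 0, 500, []
while token_offset < total_tokens:
    this_hop = min(token_hop_len, total_tokens - token_offset)
    chunks.append((token_offset, this_hop))
    token_offset += this_hop
    # same update as in the hunk above: grow geometrically, capped at the max
    token_hop_len = min(token_max_hop_len, token_hop_len * stream_scale_factor)
print(chunks)  # [(0, 25), (25, 50), (75, 100), (175, 100), (275, 100), (375, 100), (475, 25)]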
@@ -403,6 +408,10 @@ class CosyVoice3Model(CosyVoice2Model):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
+        self.token_max_hop_len = 4 * self.token_hop_len
+        self.stream_scale_factor = 2
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()

View File

@@ -17,6 +17,7 @@ import random
 import pyarrow.parquet as pq
 from io import BytesIO
 import numpy as np
+import whisper
 import torch
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
@@ -179,6 +180,23 @@ def compute_fbank(data,
         yield sample


+def compute_whisper_fbank(data, num_frames=-1, mode='train'):
+    """ Extract whisper fbank
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        if num_frames != -1:
+            assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
+        sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
+        sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
+        yield sample
+
+
 def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0
@@ -215,11 +233,12 @@ def parse_embedding(data, normalize, mode='train'):
""" """
for sample in data: for sample in data:
if 'utt_embedding' not in sample and 'spk_embedding' not in sample: if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
speech_16k = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech']) sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
embedding = embedding_extractor.inference(speech_16k) embedding = embedding_extractor.inference(sample['speech_16k'])
sample['spk_embedding'] = sample['utt_embedding'] = embedding sample['spk_embedding'] = sample['utt_embedding'] = embedding
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32) else:
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32) sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize: if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0) sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0) sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
@@ -242,8 +261,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
         if 'instruct' in sample:
             sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
-        else:
-            sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
         yield sample
@@ -371,66 +388,42 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
""" """
for sample in data: for sample in data:
assert isinstance(sample, list) assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample], order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
dtype=torch.int32) batch = {}
order = torch.argsort(speech_feat_len, descending=True) batch['utts'] = [sample[i]['utt'] for i in order]
batch['text'] = [sample[i]['text'] for i in order]
utts = [sample[i]['utt'] for i in order]
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
speech = pad_sequence(speech, batch_first=True, padding_value=0)
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order] text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32) batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0) batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order] speech_feat = [sample[i]['speech_feat'] for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32) batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0) batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0) batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0) batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = { if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
"utts": utts, instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
"speech": speech, batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
"speech_len": speech_len, batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
"speech_token": speech_token, if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
"speech_token_len": speech_token_len, whisper_feat = [torch.tensor(sample[i]['whisper_feat']) for i in order]
"speech_feat": speech_feat, batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
"speech_feat_len": speech_feat_len, batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
"text": text, if torch.tensor(['speech_token' in sample[i] for i in order]).all():
"text_token": text_token, speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
"text_token_len": text_token_len, batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
"instruct_token": instruct_token, batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if gan is True: if gan is True:
# in gan train, we need pitch_feat # in gan train, we need speech/pitch_feat
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
pitch_feat = [sample[i]['pitch_feat'] for i in order] pitch_feat = [sample[i]['pitch_feat'] for i in order]
pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32) batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
pitch_feat = pad_sequence(pitch_feat, batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
batch_first=True,
padding_value=0)
batch["pitch_feat"] = pitch_feat
batch["pitch_feat_len"] = pitch_feat_len
if dpo is True: if dpo is True:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order] reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32) batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token, batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
if use_spk_embedding is True: if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"] batch["embedding"] = batch["spk_embedding"]
else: else:
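Illustration only (not in the commit): the rewritten padding() packs optional keys (instruct_token, whisper_feat, speech_token) only when every sample in the mini-batch carries them. The sketch below isolates that pattern in a hypothetical helper named pack_optional; the helper name and wrapper are not from the repository.

# Illustration only: the "collate a key only if all samples have it" pattern used above.
import torch
from torch.nn.utils.rnn import pad_sequence

def pack_optional(batch, sample, order, key):
    # sample: list of dicts; order: indices sorted by length; key: e.g. 'speech_token' or 'whisper_feat'
    if torch.tensor([key in sample[i] for i in order]).all():
        values = [torch.tensor(sample[i][key]) for i in order]
        batch[f'{key}_len'] = torch.tensor([v.size(0) for v in values], dtype=torch.int32)
        batch[key] = pad_sequence(values, batch_first=True, padding_value=0)

sample = [{'speech_token': [1, 2, 3]}, {'speech_token': [4, 5]}]
batch = {}
pack_optional(batch, sample, order=[0, 1], key='speech_token')
print(batch['speech_token'].shape, batch['speech_token_len'])  # torch.Size([2, 3]) tensor([3, 2], dtype=torch.int32)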

View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
+import os, logging
 import random
 from typing import Dict, Optional
 import torch
@@ -19,7 +19,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 from omegaconf import DictConfig
 from cosyvoice.utils.mask import make_pad_mask
-from cosyvoice.utils.onnx import SpeechTokenExtractor
+from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


 class MaskedDiffWithXvec(torch.nn.Module):
@@ -180,14 +180,19 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.only_mask_loss = only_mask_loss
         self.token_mel_ratio = token_mel_ratio
         self.pre_lookahead_len = pre_lookahead_len
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

     def forward(
             self,
             batch: dict,
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
+        else:
+            token = batch['speech_token'].to(device)
+            token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)
@@ -309,6 +314,8 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
         self.decoder = decoder
         self.only_mask_loss = only_mask_loss
         self.token_mel_ratio = token_mel_ratio
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

     def forward(
             self,

View File

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import queue
+import os, queue
 import random
 import time
 import threading
@@ -28,7 +28,7 @@ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
-from cosyvoice.utils.onnx import SpeechTokenExtractor
+from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


 class TransformerLM(torch.nn.Module):
@@ -301,6 +301,8 @@ class Qwen2LM(TransformerLM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

     def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
@@ -667,6 +669,8 @@ class CosyVoice3LM(Qwen2LM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(200)]
         self.vllm_output_queue = {}
+        if online_feature is True:
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

     def forward(
             self,

View File

@@ -18,14 +18,13 @@ class SpeechTokenExtractor():
                                                              sess_options=option,
                                                              providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])

-    def inference(self, feat, feat_lengths, device):
-        ort_out = self.speech_tokenizer_session.run(None,
-                                                    {self.speech_tokenizer_session.get_inputs()[0].name:
-                                                     feat.detach().cpu().numpy(),
-                                                     self.speech_tokenizer_session.get_inputs()[1].name:
-                                                     feat_lengths.detach().cpu().numpy()})
-        speech_token, speech_token_embedding = ort_out[0], ort_out[1]
-        return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
+    def inference(self, feat, feat_lengths):
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.transpose(1, 2).detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          feat_lengths.detach().cpu().numpy()})[0]
+        return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)


 class EmbeddingExtractor():
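Illustration only (not in the commit): a sketch of calling the updated inference() with batched whisper features, assuming onnxruntime with a CUDA provider and an exported speech_tokenizer_v2.batch.onnx available under onnx_path. The input is (batch, frames, 128) and is transposed to (batch, 128, frames) for the ONNX graph; token lengths come out as feat_lengths / 4, consistent with whisper's ~100 mel frames per second mapping to 25 speech tokens per second.

# Illustration only: feeding batched whisper log-mels through the extractor above.
import os
import torch
from cosyvoice.utils.onnx import SpeechTokenExtractor, onnx_path

extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
whisper_feat = torch.randn(2, 400, 128)      # (B, T, 128): about 4 s of 100 frame/s whisper log-mels
whisper_feat_len = torch.tensor([400, 320])  # valid frames per utterance
token, token_len = extractor.inference(whisper_feat, whisper_feat_len)
# token_len == whisper_feat_len / 4, i.e. 100 mel frames/s -> 25 speech tokens/s
print(token.shape, token_len)                # e.g. (2, 100) tensor([100, 80], dtype=torch.int32)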

View File

@@ -159,6 +159,8 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -183,6 +185,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,

View File

@@ -149,6 +149,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
+compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
@@ -173,6 +174,7 @@ data_pipeline: [
     !ref <resample>,
     !ref <compute_fbank>,
     !ref <parse_embedding>,
+    !ref <compute_whisper_fbank>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,