This commit is contained in:
lyuxiang.lx
2026-01-29 10:29:22 +00:00
parent f26cde56df
commit 84e41729ea
4 changed files with 20 additions and 13 deletions

View File

@@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
         if 'speech_token' not in batch:
-            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
         else:
             token = batch['speech_token'].to(device)
             token_len = batch['speech_token_len'].to(device)
@@ -322,6 +322,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
             batch: dict,
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            token = batch['speech_token'].to(device)
+            token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)

View File

@@ -367,6 +367,9 @@ class Qwen2LM(TransformerLM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
         # NOTE should append instruct_token to sequence, not implemented yet
         instruct_token = batch['instruct_token'].to(device)
         instruct_token_len = batch['instruct_token_len'].to(device)

View File

@@ -1,11 +1,7 @@
 import onnxruntime
 import torch, random
-from torch import nn
 import os
-import whisper
-import numpy as np
 import torchaudio.compliance.kaldi as kaldi
-import torch.nn.functional as F


 class SpeechTokenExtractor():
@@ -18,13 +14,13 @@ class SpeechTokenExtractor():
                                                              sess_options=option,
                                                              providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])

-    def inference(self, feat, feat_lengths):
+    def inference(self, feat, feat_lengths, device):
         speech_token = self.speech_tokenizer_session.run(None,
                                                          {self.speech_tokenizer_session.get_inputs()[0].name:
                                                           feat.transpose(1, 2).detach().cpu().numpy(),
                                                           self.speech_tokenizer_session.get_inputs()[1].name:
                                                           feat_lengths.detach().cpu().numpy()})[0]
-        return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
+        return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)


 class EmbeddingExtractor():

View File

@@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
 compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480