commit 84e41729ea (parent f26cde56df)
Author: lyuxiang.lx
Date: 2026-01-29 10:29:22 +00:00
4 changed files with 20 additions and 13 deletions

View File

@@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
         if 'speech_token' not in batch:
-            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
         else:
             token = batch['speech_token'].to(device)
             token_len = batch['speech_token_len'].to(device)
@@ -322,8 +322,11 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
             batch: dict,
             device: torch.device,
     ) -> Dict[str, Optional[torch.Tensor]]:
-        token = batch['speech_token'].to(device)
-        token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            token = batch['speech_token'].to(device)
+            token_len = batch['speech_token_len'].to(device)
         feat = batch['speech_feat'].to(device)
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)
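
Both flow models now fall back to extracting speech tokens on the fly from whisper fbank features whenever the batch carries no precomputed `speech_token` (the LM classes in the next file apply the identical pattern). Below is a minimal, self-contained sketch of that pattern; the stub extractor, the `get_speech_tokens` helper, and the tensor shapes are illustrative assumptions, not part of the commit:

```python
import torch

class StubTokenExtractor:
    """Stand-in for the ONNX-backed SpeechTokenExtractor (illustrative only)."""
    def inference(self, feat, feat_len, device):
        # The real extractor runs a speech tokenizer ONNX session; here we
        # just emit dummy int32 tokens at the same 4x temporal downsampling.
        token_len = (feat_len / 4).to(torch.int32).to(device)
        token = torch.zeros(feat.size(0), int(token_len.max()), dtype=torch.int32, device=device)
        return token, token_len

def get_speech_tokens(batch, extractor, device):
    # Prefer precomputed tokens; otherwise extract them from whisper features.
    if 'speech_token' not in batch:
        return extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
    return batch['speech_token'].to(device), batch['speech_token_len'].to(device)

batch = {'whisper_feat': torch.randn(2, 960, 128),  # assumed (batch, frames, n_mels)
         'whisper_feat_len': torch.tensor([960, 720], dtype=torch.int32)}
token, token_len = get_speech_tokens(batch, StubTokenExtractor(), torch.device('cpu'))
assert token.dtype == torch.int32 and token_len.tolist() == [240, 180]
```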

View File

@@ -367,8 +367,11 @@ class Qwen2LM(TransformerLM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
         """
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
+        if 'speech_token' not in batch:
+            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
+        else:
+            speech_token = batch['speech_token'].to(device)
+            speech_token_len = batch['speech_token_len'].to(device)
         # NOTE should append instruct_token to sequence, not implemented yet
         instruct_token = batch['instruct_token'].to(device)
         instruct_token_len = batch['instruct_token_len'].to(device)

View File

@@ -1,11 +1,7 @@
 import onnxruntime
 import torch, random
-from torch import nn
-import os
 import whisper
 import numpy as np
-import torchaudio.compliance.kaldi as kaldi
-import torch.nn.functional as F
 class SpeechTokenExtractor():
@@ -18,13 +14,13 @@ class SpeechTokenExtractor():
             sess_options=option,
             providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
-    def inference(self, feat, feat_lengths):
+    def inference(self, feat, feat_lengths, device):
         speech_token = self.speech_tokenizer_session.run(None,
                                                          {self.speech_tokenizer_session.get_inputs()[0].name:
                                                           feat.transpose(1, 2).detach().cpu().numpy(),
                                                           self.speech_tokenizer_session.get_inputs()[1].name:
                                                           feat_lengths.detach().cpu().numpy()})[0]
-        return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
+        return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
 class EmbeddingExtractor():
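
The new `device` argument also fixes a dtype quirk: the old `.to(feat)` cast the integer token IDs to `feat`'s floating dtype on `feat`'s device, whereas the new return keeps them `int32` on whatever device the caller passes. A small before/after sketch (the shapes and the stand-in tokenizer output are assumptions):

```python
import torch

feat = torch.randn(2, 960, 128)            # assumed (batch, frames, n_mels) whisper fbank
ids = torch.randint(0, 4096, (2, 240))     # stand-in for the ONNX tokenizer output

# Old behavior: .to(feat) silently cast token IDs to feat's float dtype.
old_token = torch.tensor(ids.numpy()).to(feat)
assert old_token.dtype == torch.float32

# New behavior: IDs stay int32 and land on an explicit device.
device = torch.device('cpu')               # 'cuda:<local_rank>' during training
new_token = torch.tensor(ids.numpy()).to(torch.int32).to(device)
assert new_token.dtype == torch.int32 and new_token.device == device
```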

View File

@@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     num_frames: 960
 compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
+    num_frames: 960
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
     sample_rate: !ref <sample_rate>
     hop_size: 480
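
The config change gives `compute_whisper_fbank` the same `num_frames: 960` cap that `compute_fbank` already carries, so both feature streams are bounded consistently. A hypothetical sketch of what such a cap typically does; the real `cosyvoice.dataset.processor.compute_whisper_fbank` may handle long utterances differently, so this is not its actual implementation:

```python
import torch

def cap_frames(feat, num_frames=960):
    # Truncate along the time axis so downstream token extraction sees a
    # bounded sequence length; with the tokenizer's 4x downsampling,
    # 960 frames yields at most 240 speech tokens.
    return feat[:num_frames]

feat = torch.randn(1200, 128)   # assumed (frames, n_mels) whisper fbank
assert cap_frames(feat).shape == (960, 128)
```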