From 84e41729eaf60691dda2d3aefcfa365ef0381be5 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 29 Jan 2026 10:29:22 +0000 Subject: [PATCH] update --- cosyvoice/flow/flow.py | 9 ++++++--- cosyvoice/llm/llm.py | 15 +++++++++++---- cosyvoice/utils/onnx.py | 8 ++------ examples/libritts/cosyvoice3/conf/cosyvoice3.yaml | 1 + 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py index 79c0f98..c255186 100644 --- a/cosyvoice/flow/flow.py +++ b/cosyvoice/flow/flow.py @@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module): device: torch.device, ) -> Dict[str, Optional[torch.Tensor]]: if 'speech_token' not in batch: - token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len']) + token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device) else: token = batch['speech_token'].to(device) token_len = batch['speech_token_len'].to(device) @@ -322,8 +322,11 @@ class CausalMaskedDiffWithDiT(torch.nn.Module): batch: dict, device: torch.device, ) -> Dict[str, Optional[torch.Tensor]]: - token = batch['speech_token'].to(device) - token_len = batch['speech_token_len'].to(device) + if 'speech_token' not in batch: + token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device) + else: + token = batch['speech_token'].to(device) + token_len = batch['speech_token_len'].to(device) feat = batch['speech_feat'].to(device) feat_len = batch['speech_feat_len'].to(device) embedding = batch['embedding'].to(device) diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 9a15109..b17bd3a 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -367,8 +367,11 @@ class Qwen2LM(TransformerLM): """ text_token = batch['text_token'].to(device) text_token_len = batch['text_token_len'].to(device) - speech_token = batch['speech_token'].to(device) - speech_token_len = 
batch['speech_token_len'].to(device) + if 'speech_token' not in batch: + speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device) + else: + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) # 1. encode text_token text_token_emb = self.llm.model.model.embed_tokens(text_token) @@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM): """ text_token = batch['text_token'].to(device) text_token_len = batch['text_token_len'].to(device) - speech_token = batch['speech_token'].to(device) - speech_token_len = batch['speech_token_len'].to(device) + if 'speech_token' not in batch: + speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device) + else: + speech_token = batch['speech_token'].to(device) + speech_token_len = batch['speech_token_len'].to(device) + # NOTE should append instruct_token to sequence, not implemented yet instruct_token = batch['instruct_token'].to(device) instruct_token_len = batch['instruct_token_len'].to(device) diff --git a/cosyvoice/utils/onnx.py b/cosyvoice/utils/onnx.py index 44ba765..b459eb5 100644 --- a/cosyvoice/utils/onnx.py +++ b/cosyvoice/utils/onnx.py @@ -1,11 +1,7 @@ import onnxruntime import torch, random -from torch import nn import os -import whisper -import numpy as np import torchaudio.compliance.kaldi as kaldi -import torch.nn.functional as F class SpeechTokenExtractor(): @@ -18,13 +14,13 @@ class SpeechTokenExtractor(): sess_options=option, providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})]) - def inference(self, feat, feat_lengths): + def inference(self, feat, feat_lengths, device): speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.transpose(1, 2).detach().cpu().numpy(), self.speech_tokenizer_session.get_inputs()[1].name: feat_lengths.detach().cpu().numpy()})[0] - 
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device) + return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device) class EmbeddingExtractor(): diff --git a/examples/libritts/cosyvoice3/conf/cosyvoice3.yaml b/examples/libritts/cosyvoice3/conf/cosyvoice3.yaml index 85ef59b..36dfee4 100644 --- a/examples/libritts/cosyvoice3/conf/cosyvoice3.yaml +++ b/examples/libritts/cosyvoice3/conf/cosyvoice3.yaml @@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank feat_extractor: !ref <feat_extractor> num_frames: 960 compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank + num_frames: 960 compute_f0: !name:cosyvoice.dataset.processor.compute_f0 sample_rate: !ref <sample_rate> hop_size: 480