diff --git a/cosyvoice/bin/export_onnx.py b/cosyvoice/bin/export_onnx.py index dd9f009..e4857da 100644 --- a/cosyvoice/bin/export_onnx.py +++ b/cosyvoice/bin/export_onnx.py @@ -27,7 +27,7 @@ from tqdm import tqdm ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) -from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 +from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3 from cosyvoice.utils.file_utils import logging @@ -64,7 +64,10 @@ def main(): try: model = CosyVoice2(args.model_dir) except Exception: - raise TypeError('no valid model_type!') + try: + model = CosyVoice3(args.model_dir) + except Exception: + raise TypeError('no valid model_type!') # 1. export flow decoder estimator estimator = model.model.flow.decoder.estimator diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index f4acba1..7731863 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -221,7 +221,7 @@ class CosyVoice3(CosyVoice): self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model.load('{}/llm.pt'.format(model_dir), '{}/flow.pt'.format(model_dir), - '{}/bigvgan.pt'.format(model_dir)) + '{}/hift.pt'.format(model_dir)) if load_vllm: self.model.load_vllm('{}/vllm'.format(model_dir)) if load_jit: diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 2b6a918..c658996 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -447,7 +447,7 @@ class CosyVoice3Model(CosyVoice2Model): if speed != 1.0: assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') - tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) + tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel) if self.hift_cache_dict[uuid] is not None: tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) return tts_speech \ No newline at end of file diff --git a/cosyvoice/flow/DiT/dit.py b/cosyvoice/flow/DiT/dit.py index 73a5423..0d637e4 100644 --- a/cosyvoice/flow/DiT/dit.py +++ b/cosyvoice/flow/DiT/dit.py @@ -115,6 +115,7 @@ class DiT(nn.Module): mu_dim=None, long_skip_connection=False, spk_dim=None, + out_channels=None, static_chunk_size=50, num_decoding_left_chunks=2 ): @@ -137,6 +138,7 @@ class DiT(nn.Module): self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + self.out_channels = out_channels self.static_chunk_size = static_chunk_size self.num_decoding_left_chunks = num_decoding_left_chunks diff --git a/cosyvoice/utils/class_utils.py b/cosyvoice/utils/class_utils.py index c49de00..c52fec4 100644 --- a/cosyvoice/utils/class_utils.py +++ b/cosyvoice/utils/class_utils.py @@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention, from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling from cosyvoice.llm.llm import TransformerLM, Qwen2LM -from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec -from cosyvoice.hifigan.generator import HiFTGenerator +from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT +from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model @@ -80,4 +80,6 @@ def get_model_type(configs): return CosyVoiceModel if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): return CosyVoice2Model + if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator): + return CosyVoice2Model raise TypeError('No valid model type found!')