This commit is contained in:
lyuxiang.lx
2025-08-21 11:45:36 +08:00
parent dd2d926147
commit 8c96081f94
5 changed files with 13 additions and 6 deletions

View File

@@ -27,7 +27,7 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR)) sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
from cosyvoice.utils.file_utils import logging from cosyvoice.utils.file_utils import logging
@@ -64,7 +64,10 @@ def main():
try: try:
model = CosyVoice2(args.model_dir) model = CosyVoice2(args.model_dir)
except Exception: except Exception:
raise TypeError('no valid model_type!') try:
model = CosyVoice3(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
# 1. export flow decoder estimator # 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator estimator = model.model.flow.decoder.estimator

View File

@@ -221,7 +221,7 @@ class CosyVoice3(CosyVoice):
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16) self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir), self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir), '{}/flow.pt'.format(model_dir),
'{}/bigvgan.pt'.format(model_dir)) '{}/hift.pt'.format(model_dir))
if load_vllm: if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir)) self.model.load_vllm('{}/vllm'.format(model_dir))
if load_jit: if load_jit:

View File

@@ -447,7 +447,7 @@ class CosyVoice3Model(CosyVoice2Model):
if speed != 1.0: if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode' assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear') tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source) tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
if self.hift_cache_dict[uuid] is not None: if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech return tts_speech

View File

@@ -115,6 +115,7 @@ class DiT(nn.Module):
mu_dim=None, mu_dim=None,
long_skip_connection=False, long_skip_connection=False,
spk_dim=None, spk_dim=None,
out_channels=None,
static_chunk_size=50, static_chunk_size=50,
num_decoding_left_chunks=2 num_decoding_left_chunks=2
): ):
@@ -137,6 +138,7 @@ class DiT(nn.Module):
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim) self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks self.num_decoding_left_chunks = num_decoding_left_chunks

View File

@@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
from cosyvoice.llm.llm import TransformerLM, Qwen2LM from cosyvoice.llm.llm import TransformerLM, Qwen2LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
from cosyvoice.hifigan.generator import HiFTGenerator from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
@@ -80,4 +80,6 @@ def get_model_type(configs):
return CosyVoiceModel return CosyVoiceModel
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoice2Model return CosyVoice2Model
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
return CosyVoice2Model
raise TypeError('No valid model type found!') raise TypeError('No valid model type found!')