mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update
This commit is contained in:
@@ -27,7 +27,7 @@ from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
@@ -64,7 +64,10 @@ def main():
|
||||
try:
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
try:
|
||||
model = CosyVoice3(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
|
||||
@@ -221,7 +221,7 @@ class CosyVoice3(CosyVoice):
|
||||
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/bigvgan.pt'.format(model_dir))
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_vllm:
|
||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||
if load_jit:
|
||||
|
||||
@@ -447,7 +447,7 @@ class CosyVoice3Model(CosyVoice2Model):
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
@@ -115,6 +115,7 @@ class DiT(nn.Module):
|
||||
mu_dim=None,
|
||||
long_skip_connection=False,
|
||||
spk_dim=None,
|
||||
out_channels=None,
|
||||
static_chunk_size=50,
|
||||
num_decoding_left_chunks=2
|
||||
):
|
||||
@@ -137,6 +138,7 @@ class DiT(nn.Module):
|
||||
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
self.out_channels = out_channels
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||
|
||||
|
||||
@@ -80,4 +80,6 @@ def get_model_type(configs):
|
||||
return CosyVoiceModel
|
||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoice2Model
|
||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
||||
return CosyVoice2Model
|
||||
raise TypeError('No valid model type found!')
|
||||
|
||||
Reference in New Issue
Block a user