mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update
This commit is contained in:
@@ -27,7 +27,7 @@ from tqdm import tqdm
|
|||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
|
||||||
from cosyvoice.utils.file_utils import logging
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +64,10 @@ def main():
|
|||||||
try:
|
try:
|
||||||
model = CosyVoice2(args.model_dir)
|
model = CosyVoice2(args.model_dir)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise TypeError('no valid model_type!')
|
try:
|
||||||
|
model = CosyVoice3(args.model_dir)
|
||||||
|
except Exception:
|
||||||
|
raise TypeError('no valid model_type!')
|
||||||
|
|
||||||
# 1. export flow decoder estimator
|
# 1. export flow decoder estimator
|
||||||
estimator = model.model.flow.decoder.estimator
|
estimator = model.model.flow.decoder.estimator
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ class CosyVoice3(CosyVoice):
|
|||||||
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/bigvgan.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
if load_vllm:
|
if load_vllm:
|
||||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||||
if load_jit:
|
if load_jit:
|
||||||
|
|||||||
@@ -447,7 +447,7 @@ class CosyVoice3Model(CosyVoice2Model):
|
|||||||
if speed != 1.0:
|
if speed != 1.0:
|
||||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
|
||||||
if self.hift_cache_dict[uuid] is not None:
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
return tts_speech
|
return tts_speech
|
||||||
@@ -115,6 +115,7 @@ class DiT(nn.Module):
|
|||||||
mu_dim=None,
|
mu_dim=None,
|
||||||
long_skip_connection=False,
|
long_skip_connection=False,
|
||||||
spk_dim=None,
|
spk_dim=None,
|
||||||
|
out_channels=None,
|
||||||
static_chunk_size=50,
|
static_chunk_size=50,
|
||||||
num_decoding_left_chunks=2
|
num_decoding_left_chunks=2
|
||||||
):
|
):
|
||||||
@@ -137,6 +138,7 @@ class DiT(nn.Module):
|
|||||||
|
|
||||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||||
self.proj_out = nn.Linear(dim, mel_dim)
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
self.out_channels = out_channels
|
||||||
self.static_chunk_size = static_chunk_size
|
self.static_chunk_size = static_chunk_size
|
||||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||||
|
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
|||||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
||||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
|
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
||||||
from cosyvoice.hifigan.generator import HiFTGenerator
|
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
||||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
||||||
|
|
||||||
|
|
||||||
@@ -80,4 +80,6 @@ def get_model_type(configs):
|
|||||||
return CosyVoiceModel
|
return CosyVoiceModel
|
||||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||||
return CosyVoice2Model
|
return CosyVoice2Model
|
||||||
|
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
||||||
|
return CosyVoice2Model
|
||||||
raise TypeError('No valid model type found!')
|
raise TypeError('No valid model type found!')
|
||||||
|
|||||||
Reference in New Issue
Block a user