commit c693039d14 (parent 2345ce6be2)
Author: lyuxiang.lx
Date:   2024-12-12 16:46:28 +08:00

6 changed files with 145 additions and 71 deletions

README.md

@@ -1,4 +1,8 @@
 # CosyVoice
+## 👉🏻 [CosyVoice2 Demos](https://funaudiollm.github.io/cosyvoice2/) 👈🏻
+[[CosyVoice2 Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice2 Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)]
 ## 👉🏻 [CosyVoice Demos](https://fun-audio-llm.github.io/) 👈🏻
 [[CosyVoice Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)][[CosyVoice Code](https://github.com/FunAudioLLM/CosyVoice)]
@@ -6,6 +10,11 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
 ## Roadmap
+- [x] 2024/12
+    - [x] CosyVoice2-0.5B model release
+    - [x] CosyVoice2-0.5B streaming inference with no quality degradation
 - [x] 2024/07
     - [x] Flow matching training support
@@ -24,9 +33,8 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
 - [ ] TBD
-    - [ ] 25hz llama based llm model which supports lora finetune
-    - [ ] Support more instruction mode
-    - [ ] Music generation
+    - [ ] CosyVoice2-0.5B bistream inference support
+    - [ ] CosyVoice2-0.5B training and finetune recipe
     - [ ] CosyVoice-500M trained with more multi-lingual data
     - [ ] More...
@@ -46,7 +54,7 @@ git submodule update --init --recursive
 - Create Conda env:
 ``` sh
-conda create -n cosyvoice python=3.8
+conda create -n cosyvoice python=3.10
 conda activate cosyvoice
 # pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platforms.
 conda install -y -c conda-forge pynini==2.1.5
@@ -68,6 +76,7 @@ If you are expert in this field, and you are only interested in training your ow
 ``` python
 # Download the models via the ModelScope SDK
 from modelscope import snapshot_download
+snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
 snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
 snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
 snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
@@ -78,6 +87,7 @@ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice
 ``` sh
 # To download the models via git, make sure git lfs is installed first
 mkdir -p pretrained_models
+git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
 git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
@@ -97,9 +107,11 @@ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
 **Basic Usage**
-For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
+For zero_shot/cross_lingual inference, please use the `CosyVoice2-0.5B` or `CosyVoice-300M` model.
 For sft inference, please use the `CosyVoice-300M-SFT` model.
 For instruct inference, please use the `CosyVoice-300M-Instruct` model.
+We strongly recommend using the `CosyVoice2-0.5B` model for better streaming performance.
 First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
 ``` sh
@@ -107,10 +119,18 @@ export PYTHONPATH=third_party/Matcha-TTS
 ```
 ``` python
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio

+## cosyvoice2 usage
+cosyvoice2 = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, load_trt=False)
+# zero_shot usage
+prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
+for i, j in enumerate(cosyvoice2.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice2.sample_rate)
+
+## cosyvoice usage
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
 # sft usage
 print(cosyvoice.list_avaliable_spks())
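A minimal sketch of what a streaming client built on the new `CosyVoice2` API might look like; it is illustrative only, assumes the `CosyVoice2-0.5B` checkpoint and a 16 kHz `zero_shot_prompt.wav` are present locally, and uses placeholder target/prompt text.

``` python
import time

import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice2 = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, load_trt=False)
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)

start = time.time()
chunks = []
# stream=True yields audio chunk by chunk instead of one utterance-level tensor
for i, out in enumerate(cosyvoice2.inference_zero_shot('Text to synthesize.', 'Transcript of the prompt audio.', prompt_speech_16k, stream=True)):
    if i == 0:
        print('first chunk after {:.2f}s'.format(time.time() - start))  # rough first-packet latency
    chunks.append(out['tts_speech'])
torchaudio.save('zero_shot_stream.wav', torch.cat(chunks, dim=1), cosyvoice2.sample_rate)
```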
@@ -189,5 +209,16 @@ You can also scan the QR code to join our official Dingding chat group.
 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
+## Citations
+``` bibtex
+@article{du2024cosyvoice,
+  title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
+  author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
+  journal={arXiv preprint arXiv:2407.05407},
+  year={2024}
+}
+```
 ## Disclaimer
 The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.

cosyvoice/cli/cosyvoice.py

@@ -38,6 +38,7 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
             load_jit = False
             fp16 = False
@@ -64,7 +65,7 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
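The `rtf` figure in these logs is wall-clock time divided by the duration of the audio yielded, with the duration now derived from the configured sample rate rather than a hard-coded 22050 Hz. A quick worked example with made-up numbers:

``` python
# 72000 samples at 24 kHz is 3.0 s of audio; producing it in 0.9 s gives an RTF of 0.3.
sample_rate = 24000
speech_len = 72000 / sample_rate      # 3.0 seconds of synthesized speech
rtf = 0.9 / speech_len                # 0.3, i.e. faster than real time
```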
@@ -74,11 +75,11 @@ class CosyVoice:
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             if len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -87,11 +88,11 @@ class CosyVoice:
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -105,23 +106,23 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()

     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
-        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
         for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
-            speech_len = model_output['tts_speech'].shape[1] / 22050
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
             start_time = time.time()

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -135,18 +136,21 @@ class CosyVoice2(CosyVoice):
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
-        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and load_jit is True:
             load_jit = False
-            fp16 = False
-            logging.warning('cpu do not support fp16 and jit, force set to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+            logging.warning('cpu do not support jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir),
-                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_trt is True and load_onnx is True:
+            load_onnx = False
+            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
         if load_onnx:
             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs
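With the reworked constructor, backend selection for the flow decoder estimator might look like the sketch below. This is illustrative usage, not part of the commit, and it assumes the exported `flow.decoder.estimator.fp32.onnx` and `flow.decoder.estimator.fp16.Volta.plan` files sit next to the checkpoint.

``` python
from cosyvoice.cli.cosyvoice import CosyVoice2

# Default: PyTorch flow decoder, no jit, sample rate read from the model config.
tts = CosyVoice2('pretrained_models/CosyVoice2-0.5B')

# ONNX Runtime estimator for the flow decoder.
tts_onnx = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_onnx=True)

# TensorRT estimator; setting load_onnx=True at the same time is ignored with a warning.
tts_trt = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_trt=True)
```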

cosyvoice/cli/frontend.py

@@ -142,11 +142,11 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
@@ -157,8 +157,8 @@ class CosyVoiceFrontEnd:
                        'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -175,10 +175,10 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input

-    def frontend_vc(self, source_speech_16k, prompt_speech_16k):
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
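The frontend change amounts to resampling the 16 kHz prompt to whatever sample rate the loaded model synthesizes at before mel feature extraction, instead of always resampling to 22050 Hz. A standalone sketch, assuming a 24000 Hz target as in the CosyVoice2 config (the original models keep 22050 Hz):

``` python
import torchaudio

prompt_speech_16k, sr = torchaudio.load('zero_shot_prompt.wav')  # assumed 16 kHz mono prompt
assert sr == 16000
resample_rate = 24000  # configs['sample_rate'] passed down from CosyVoice2
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
```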

cosyvoice/cli/model.py

@@ -261,16 +261,15 @@ class CosyVoice2Model:
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module,
-                 fp16: bool):
+                 hift: torch.nn.Module):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.fp16 = fp16
-        self.token_min_hop_len = 1 * self.flow.input_frame_rate
-        self.token_max_hop_len = 2 * self.flow.input_frame_rate
-        self.token_right_context = self.flow.encoder.pre_lookahead_layer.pre_lookahead_len
+        self.token_hop_len = 2 * self.flow.input_frame_rate
+        # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
+        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -278,7 +277,6 @@ class CosyVoice2Model:
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
-        assert self.stream_scale_factor == 1, 'fix stream_scale_factor to 1 as we haven\'t implement cache in flow matching yet, this constraint will be loosen in the future'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
@@ -293,17 +291,13 @@ class CosyVoice2Model:
             self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
+        self.flow.decoder.fp16 = False
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()

-    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
-        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
-        self.llm.text_encoder = llm_text_encoder
-        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
-        self.llm.llm = llm_llm
+    def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
@@ -316,6 +310,14 @@ class CosyVoice2Model:
         del self.flow.decoder.estimator
         self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

+    def load_trt(self, flow_decoder_estimator_model):
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.fp16 = True
+
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         if self.fp16 is True:
             llm_embedding = llm_embedding.half()
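`load_trt` only deserializes an existing engine, so the `.plan` file has to be built from the exported ONNX estimator beforehand (for example with `trtexec --onnx=... --saveEngine=... --fp16`). The commit does not show that step; below is a rough sketch using the TensorRT Python API, assuming TensorRT 10 semantics, input names matching `forward_estimator`, and purely illustrative shape bounds for the dynamic time axis.

``` python
import tensorrt as trt

ONNX_PATH = 'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx'
PLAN_PATH = 'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.Volta.plan'

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network()          # explicit batch is the default in TensorRT 10
parser = trt.OnnxParser(network, logger)
with open(ONNX_PATH, 'rb') as f:
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)       # the loader above expects an fp16 engine

# The time axis is dynamic, so an optimization profile is required; the bounds are guesses.
profile = builder.create_optimization_profile()
for name, channels in [('x', 80), ('mask', 1), ('mu', 80), ('cond', 80)]:
    profile.set_shape(name, (2, channels, 4), (2, channels, 500), (2, channels, 3000))
config.add_optimization_profile(profile)

with open(PLAN_PATH, 'wb') as f:
    f.write(builder.build_serialized_network(network, config))
```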
@@ -339,7 +341,7 @@ class CosyVoice2Model:
                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                          embedding=embedding.to(self.device),
                                          finalize=finalize)
-        tts_mel = tts_mel[:, :, token_offset * self.flow.encoder.up_layer.stride:]
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -377,13 +379,11 @@ class CosyVoice2Model:
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
-            token_hop_len, token_offset = self.token_min_hop_len, 0
-            self.flow.encoder.static_chunk_size = self.token_min_hop_len
-            self.flow.decoder.estimator.static_chunk_size = self.token_min_hop_len * self.flow.encoder.up_layer.stride
+            token_offset = 0
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= token_hop_len + self.token_right_context:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + token_hop_len + self.token_right_context]) \
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
                         .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
@@ -392,11 +392,9 @@ class CosyVoice2Model:
                                                      uuid=this_uuid,
                                                      token_offset=token_offset,
                                                      finalize=False)
-                    token_offset += token_hop_len
+                    token_offset += self.token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
-                    # increase token_hop_len for better speech quality
-                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < token_hop_len + self.token_right_context:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -412,14 +410,13 @@ class CosyVoice2Model:
         else:
             # deal with all tokens
             p.join()
-            self.flow.encoder.static_chunk_size = 0
-            self.flow.decoder.estimator.static_chunk_size = 0
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
+                                             token_offset=0,
                                              finalize=True,
                                              speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
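The streaming loop now uses a fixed hop instead of a growing one: a chunk is vocoded as soon as `token_hop_len + pre_lookahead_len` unconsumed tokens are available, and the offset then advances by `token_hop_len`. A toy simulation of that schedule (the 50-token hop follows from an assumed `input_frame_rate` of 25; the 3-token look-ahead is likewise an assumption for illustration):

``` python
# Toy model of the fixed-hop chunking in CosyVoice2Model.tts with stream=True.
input_frame_rate = 25                  # speech tokens per second (assumed)
token_hop_len = 2 * input_frame_rate   # 50 tokens, i.e. roughly 2 s of audio per chunk
pre_lookahead_len = 3                  # assumed flow-encoder look-ahead

tokens, token_offset, emitted = [], 0, 0
for step in range(200):                # pretend the LLM emits one token per loop iteration
    tokens.append(step)
    if len(tokens) - token_offset >= token_hop_len + pre_lookahead_len:
        # vocode tokens[:token_offset + token_hop_len + pre_lookahead_len], then advance one hop
        emitted += 1
        token_offset += token_hop_len
print(emitted, token_offset)           # 3 chunks emitted, offset at 150 of 200 tokens; the rest is finalized at the end
```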

cosyvoice/flow/flow_matching.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -88,15 +89,25 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
+        if self.inference_cfg_rate > 0:
+            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                x_in = torch.concat([x, x], dim=0)
-                mask_in = torch.concat([mask, mask], dim=0)
-                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
-                t_in = torch.concat([t, t], dim=0)
-                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
-                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
             else:
                 x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
             dphi_dt = self.forward_estimator(
@@ -114,22 +125,53 @@ class ConditionalCFM(BASECFM):
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
-        return sol[-1]
+        return sol[-1].float()

     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        else:
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
             ort_inputs = {
                 'x': x.cpu().numpy(),
                 'mask': mask.cpu().numpy(),
                 'mu': mu.cpu().numpy(),
                 't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
+                'spk': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+                'mask_rand': torch.randn(1, 1, 1).numpy()
             }
             output = self.estimator.run(None, ort_inputs)[0]
             return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            if not x.is_contiguous():
+                x = x.contiguous()
+            if not mask.is_contiguous():
+                mask = mask.contiguous()
+            if not mu.is_contiguous():
+                mu = mu.contiguous()
+            if not t.is_contiguous():
+                t = t.contiguous()
+            if not spks.is_contiguous():
+                spks = spks.contiguous()
+            if not cond.is_contiguous():
+                cond = cond.contiguous()
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spk', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask_rand', (1, 1, 1))
+            # run trt engine
+            self.estimator.execute_v2([x.data_ptr(),
+                                       mask.data_ptr(),
+                                       mu.data_ptr(),
+                                       t.data_ptr(),
+                                       spks.data_ptr(),
+                                       cond.data_ptr(),
+                                       torch.randn(1, 1, 1).to(x.device).data_ptr(),
+                                       x.data_ptr()])
+            return x
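For context on the pre-allocated two-row batch: row 0 carries the conditional inputs and row 1 the unconditional ones (zeroed `mu`, `spks`, `cond`), and after the estimator call the two halves are blended with the guidance rate. That blending step sits outside this hunk; a sketch of the usual classifier-free-guidance form, with `inference_cfg_rate` as in the surrounding class:

``` python
import torch

def cfg_blend(dphi_dt: torch.Tensor, inference_cfg_rate: float) -> torch.Tensor:
    # dphi_dt has batch size 2: [conditional estimate, unconditional estimate]
    cond_dphi, uncond_dphi = torch.split(dphi_dt, dphi_dt.size(0) // 2, dim=0)
    return (1.0 + inference_cfg_rate) * cond_dphi - inference_cfg_rate * uncond_dphi
```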
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
@@ -199,7 +241,8 @@ class CausalConditionalCFM(ConditionalCFM):
""" """
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
z[:] = 0 if self.sp16 is True:
z = z.half()
# fix prompt and overlap part mu and z # fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine': if self.t_scheduler == 'cosine':

requirements.txt

@@ -1,5 +1,4 @@
---extra-index-url https://download.pytorch.org/whl/torch_stable.html
-conformer==0.3.2
+--extra-index-url https://download.pytorch.org/whl/cu121
 deepspeed==0.14.2; sys_platform == 'linux'
 diffusers==0.27.2
 gdown==5.1.0
@@ -26,8 +25,8 @@ rich==13.7.1
 soundfile==0.12.1
 tensorboard==2.14.0
 tensorrt-cu12==10.0.1
-torch==2.3.1+cu121
-torchaudio==2.3.1+cu121
+torch==2.3.1
+torchaudio==2.3.1
 uvicorn==0.30.0
 wget==3.2
 fastapi==0.111.0