mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Compare commits
47 Commits
flow_cache
...
v2.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8555549e88 | ||
|
|
3047591fad | ||
|
|
5a00aefa20 | ||
|
|
35abb1f3b5 | ||
|
|
21a5efd8ae | ||
|
|
44316c3475 | ||
|
|
116c99bf39 | ||
|
|
6eaef42126 | ||
|
|
525531d8a3 | ||
|
|
c788bca1a6 | ||
|
|
ff9694dc2c | ||
|
|
4505924608 | ||
|
|
46dfe0439b | ||
|
|
63856565f3 | ||
|
|
cc234bd322 | ||
|
|
98ef35b1c0 | ||
|
|
fca1df3ea7 | ||
|
|
c939c80480 | ||
|
|
5bf6befd70 | ||
|
|
1c7976779b | ||
|
|
a8e1774e82 | ||
|
|
1f50ae259b | ||
|
|
793a7fe6ad | ||
|
|
79b2ea818b | ||
|
|
7d9d84d32d | ||
|
|
9b052a94c4 | ||
|
|
6dd68b9d5e | ||
|
|
9f55c5af8f | ||
|
|
b6c5f9dfd2 | ||
|
|
cbfed4a9ee | ||
|
|
54d21b40f0 | ||
|
|
c3250c222f | ||
|
|
82219cdd27 | ||
|
|
4159a18469 | ||
|
|
3e12bb86bd | ||
|
|
cbfbe2bc33 | ||
|
|
3660da4a19 | ||
|
|
3c921daede | ||
|
|
68100c267a | ||
|
|
88f467a8ac | ||
|
|
038ff9f353 | ||
|
|
65ad448714 | ||
|
|
a442317d17 | ||
|
|
7f8bea2669 | ||
|
|
6d876f573c | ||
|
|
634edfadf0 | ||
|
|
a22873e360 |
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -52,5 +52,5 @@ jobs:
|
||||
set -eux
|
||||
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
|
||||
flake8 --version
|
||||
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||
if [ $? != 0 ]; then exit 1; fi
|
||||
119
README.md
119
README.md
@@ -1,6 +1,9 @@
|
||||
[](https://github.com/Akshay090/svg-banners)
|
||||
|
||||
## 👉🏻 CosyVoice 👈🏻
|
||||
|
||||
**CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
|
||||
|
||||
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
|
||||
|
||||
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
|
||||
@@ -26,6 +29,14 @@
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] 2025/07
|
||||
|
||||
- [x] release cosyvoice 3.0 eval set
|
||||
|
||||
- [x] 2025/05
|
||||
|
||||
- [x] add cosyvoice 2.0 vllm support
|
||||
|
||||
- [x] 2024/12
|
||||
|
||||
- [x] 25hz cosyvoice 2.0 released
|
||||
@@ -49,34 +60,32 @@
|
||||
|
||||
## Install
|
||||
|
||||
**Clone and install**
|
||||
### Clone and install
|
||||
|
||||
- Clone the repo
|
||||
``` sh
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
# If you failed to clone submodule due to network failures, please run following command until success
|
||||
cd CosyVoice
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
``` sh
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
# If you failed to clone the submodule due to network failures, please run the following command until success
|
||||
cd CosyVoice
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
||||
- Create Conda env:
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice -y python=3.10
|
||||
conda activate cosyvoice
|
||||
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
|
||||
conda install -y -c conda-forge pynini==2.1.5
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
``` sh
|
||||
conda create -n cosyvoice -y python=3.10
|
||||
conda activate cosyvoice
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
|
||||
# If you encounter sox compatibility issues
|
||||
# ubuntu
|
||||
sudo apt-get install sox libsox-dev
|
||||
# centos
|
||||
sudo yum install sox sox-devel
|
||||
```
|
||||
# If you encounter sox compatibility issues
|
||||
# ubuntu
|
||||
sudo apt-get install sox libsox-dev
|
||||
# centos
|
||||
sudo yum install sox sox-devel
|
||||
```
|
||||
|
||||
**Model download**
|
||||
### Model download
|
||||
|
||||
We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
||||
|
||||
@@ -100,9 +109,9 @@ git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_m
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
||||
```
|
||||
|
||||
Optionally, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
|
||||
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
|
||||
|
||||
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
||||
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
|
||||
|
||||
``` sh
|
||||
cd pretrained_models/CosyVoice-ttsfrd/
|
||||
@@ -111,10 +120,10 @@ pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
||||
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
||||
```
|
||||
|
||||
**Basic Usage**
|
||||
### Basic Usage
|
||||
|
||||
We strongly recommend using `CosyVoice2-0.5B` for better performance.
|
||||
Follow code below for detailed usage of each model.
|
||||
Follow the code below for detailed usage of each model.
|
||||
|
||||
``` python
|
||||
import sys
|
||||
@@ -124,9 +133,9 @@ from cosyvoice.utils.file_utils import load_wav
|
||||
import torchaudio
|
||||
```
|
||||
|
||||
**CosyVoice2 Usage**
|
||||
#### CosyVoice2 Usage
|
||||
```python
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False)
|
||||
|
||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||
# zero_shot usage
|
||||
@@ -159,7 +168,19 @@ for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
```
|
||||
|
||||
**CosyVoice Usage**
|
||||
#### CosyVoice2 vllm Usage
|
||||
If you want to use vllm for inference, please install `vllm==v0.9.0`. Older vllm version do not support CosyVoice2 inference.
|
||||
|
||||
Notice that `vllm==v0.9.0` has a lot of specific requirements, for example `torch==2.7.0`. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice_vllm --clone cosyvoice
|
||||
conda activate cosyvoice_vllm
|
||||
pip install vllm==v0.9.0 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
python vllm_example.py
|
||||
```
|
||||
|
||||
#### CosyVoice Usage
|
||||
```python
|
||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
|
||||
# sft usage
|
||||
@@ -189,7 +210,7 @@ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
```
|
||||
|
||||
**Start web demo**
|
||||
#### Start web demo
|
||||
|
||||
You can use our web demo page to get familiar with CosyVoice quickly.
|
||||
|
||||
@@ -200,14 +221,14 @@ Please see the demo website for details.
|
||||
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
||||
```
|
||||
|
||||
**Advanced Usage**
|
||||
#### Advanced Usage
|
||||
|
||||
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
||||
For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
||||
|
||||
**Build for deployment**
|
||||
#### Build for deployment
|
||||
|
||||
Optionally, if you want service deployment,
|
||||
you can run following steps.
|
||||
You can run the following steps.
|
||||
|
||||
``` sh
|
||||
cd runtime/python
|
||||
@@ -237,5 +258,39 @@ You can also scan the QR code to join our official Dingding chat group.
|
||||
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
||||
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
||||
|
||||
## Citations
|
||||
|
||||
``` bibtex
|
||||
@article{du2024cosyvoice,
|
||||
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
|
||||
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
|
||||
journal={arXiv preprint arXiv:2407.05407},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@article{du2024cosyvoice,
|
||||
title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
|
||||
author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
|
||||
journal={arXiv preprint arXiv:2412.10117},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@article{du2025cosyvoice,
|
||||
title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
|
||||
author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
|
||||
journal={arXiv preprint arXiv:2505.17589},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@inproceedings{lyu2025build,
|
||||
title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
|
||||
author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
|
||||
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
||||
pages={1--2},
|
||||
year={2025},
|
||||
organization={IEEE}
|
||||
}
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 94 KiB |
@@ -61,8 +61,7 @@ def main():
|
||||
model = CosyVoice(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
# NOTE set use_flow_cache=True when export jit for cache inference
|
||||
model = CosyVoice2(args.model_dir, use_flow_cache=True)
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
@@ -93,9 +92,9 @@ def main():
|
||||
else:
|
||||
# 3. export flow encoder
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder, ['forward_chunk'])
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
|
||||
|
||||
@@ -62,135 +62,58 @@ def main():
|
||||
model = CosyVoice(args.model_dir)
|
||||
except Exception:
|
||||
try:
|
||||
# NOTE set use_flow_cache=True when export jit for cache inference
|
||||
model = CosyVoice2(args.model_dir, use_flow_cache=True)
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
if not isinstance(model, CosyVoice2):
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
estimator.eval()
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
estimator.eval()
|
||||
|
||||
device = model.model.device
|
||||
batch_size, seq_len = 2, 256
|
||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
(x, mask, mu, t, spks, cond),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||
output_names=['estimator_out'],
|
||||
dynamic_axes={
|
||||
'x': {2: 'seq_len'},
|
||||
'mask': {2: 'seq_len'},
|
||||
'mu': {2: 'seq_len'},
|
||||
'cond': {2: 'seq_len'},
|
||||
'estimator_out': {2: 'seq_len'},
|
||||
}
|
||||
)
|
||||
device = model.model.device
|
||||
batch_size, seq_len = 2, 256
|
||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
(x, mask, mu, t, spks, cond),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||
output_names=['estimator_out'],
|
||||
dynamic_axes={
|
||||
'x': {2: 'seq_len'},
|
||||
'mask': {2: 'seq_len'},
|
||||
'mu': {2: 'seq_len'},
|
||||
'cond': {2: 'seq_len'},
|
||||
'estimator_out': {2: 'seq_len'},
|
||||
}
|
||||
)
|
||||
|
||||
# 2. test computation consistency
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
# 2. test computation consistency
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||
logging.info('successfully export estimator')
|
||||
else:
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
estimator.forward = estimator.forward_chunk
|
||||
estimator.eval()
|
||||
|
||||
device = model.model.device
|
||||
batch_size, seq_len = 2, 256
|
||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
cache = model.model.init_flow_cache()['decoder_cache']
|
||||
cache.pop('offset')
|
||||
cache = {k: v[0] for k, v in cache.items()}
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
(x, mask, mu, t, spks, cond,
|
||||
cache['down_blocks_conv_cache'],
|
||||
cache['down_blocks_kv_cache'],
|
||||
cache['mid_blocks_conv_cache'],
|
||||
cache['mid_blocks_kv_cache'],
|
||||
cache['up_blocks_conv_cache'],
|
||||
cache['up_blocks_kv_cache'],
|
||||
cache['final_blocks_conv_cache']),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
|
||||
'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
|
||||
output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
|
||||
'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
|
||||
dynamic_axes={
|
||||
'x': {2: 'seq_len'},
|
||||
'mask': {2: 'seq_len'},
|
||||
'mu': {2: 'seq_len'},
|
||||
'cond': {2: 'seq_len'},
|
||||
'down_blocks_kv_cache': {3: 'cache_in_len'},
|
||||
'mid_blocks_kv_cache': {3: 'cache_in_len'},
|
||||
'up_blocks_kv_cache': {3: 'cache_in_len'},
|
||||
'estimator_out': {2: 'seq_len'},
|
||||
'down_blocks_kv_cache_out': {3: 'cache_out_len'},
|
||||
'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
|
||||
'up_blocks_kv_cache_out': {3: 'cache_out_len'},
|
||||
}
|
||||
)
|
||||
|
||||
# 2. test computation consistency
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for iter in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||
cache = model.model.init_flow_cache()['decoder_cache']
|
||||
cache.pop('offset')
|
||||
cache = {k: v[0] for k, v in cache.items()}
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy(),
|
||||
}
|
||||
output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
|
||||
if iter == 0:
|
||||
# NOTE why can not pass first iteration check?
|
||||
continue
|
||||
for i, j in zip(output_pytorch, output_onnx):
|
||||
torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
|
||||
logging.info('successfully export estimator')
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||
logging.info('successfully export estimator')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -122,4 +122,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
|
||||
main()
|
||||
@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.errors import record
|
||||
|
||||
from cosyvoice.utils.losses import DPOLoss
|
||||
from cosyvoice.utils.executor import Executor
|
||||
from cosyvoice.utils.train_utils import (
|
||||
init_distributed,
|
||||
@@ -43,6 +44,7 @@ def get_args():
|
||||
choices=['torch_ddp', 'deepspeed'],
|
||||
help='Engine for paralleled training')
|
||||
parser.add_argument('--model', required=True, help='model which will be trained')
|
||||
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--train_data', required=True, help='train data file')
|
||||
parser.add_argument('--cv_data', required=True, help='cv data file')
|
||||
@@ -73,6 +75,10 @@ def get_args():
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use automatic mixed precision training')
|
||||
parser.add_argument('--dpo',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use Direct Preference Optimization')
|
||||
parser.add_argument('--deepspeed.save_states',
|
||||
dest='save_states',
|
||||
default='model_only',
|
||||
@@ -113,7 +119,7 @@ def main():
|
||||
|
||||
# Get dataset & dataloader
|
||||
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
||||
init_dataset_and_dataloader(args, configs, gan)
|
||||
init_dataset_and_dataloader(args, configs, gan, args.dpo)
|
||||
|
||||
# Do some sanity checks and save config to arsg.model_dir
|
||||
configs = check_modify_and_save_config(args, configs)
|
||||
@@ -122,6 +128,8 @@ def main():
|
||||
writer = init_summarywriter(args)
|
||||
|
||||
# load checkpoint
|
||||
if args.dpo is True:
|
||||
configs[args.model].forward = configs[args.model].forward_dpo
|
||||
model = configs[args.model]
|
||||
start_step, start_epoch = 0, -1
|
||||
if args.checkpoint is not None:
|
||||
@@ -150,13 +158,25 @@ def main():
|
||||
info_dict['epoch'] = start_epoch
|
||||
save_model(model, 'init', info_dict)
|
||||
|
||||
# DPO related
|
||||
if args.dpo is True:
|
||||
ref_model = deepcopy(configs[args.model])
|
||||
state_dict = torch.load(args.ref_model, map_location='cpu')
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
|
||||
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
|
||||
ref_model = wrap_cuda_model(args, ref_model)
|
||||
else:
|
||||
ref_model, dpo_loss = None, None
|
||||
|
||||
# Get executor
|
||||
executor = Executor(gan=gan)
|
||||
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
|
||||
executor.step = start_step
|
||||
|
||||
# Init scaler, used for pytorch amp mixed precision training
|
||||
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
||||
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
||||
|
||||
# Start training loop
|
||||
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
||||
executor.epoch = epoch
|
||||
@@ -167,7 +187,7 @@ def main():
|
||||
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, scaler, group_join)
|
||||
else:
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
|
||||
dist.destroy_process_group(group_join)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from cosyvoice.utils.class_utils import get_model_type
|
||||
|
||||
class CosyVoice:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
self.instruct = True if '-Instruct' in model_dir else False
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
@@ -59,6 +59,7 @@ class CosyVoice:
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
@@ -140,7 +141,7 @@ class CosyVoice:
|
||||
|
||||
class CosyVoice2(CosyVoice):
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||
self.instruct = True if '-Instruct' in model_dir else False
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
@@ -162,15 +163,18 @@ class CosyVoice2(CosyVoice):
|
||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||
load_jit, load_trt, fp16 = False, False, False
|
||||
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_vllm:
|
||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
|
||||
@@ -28,9 +28,9 @@ try:
|
||||
import ttsfrd
|
||||
use_ttsfrd = True
|
||||
except ImportError:
|
||||
print("failed to import ttsfrd, use WeTextProcessing instead")
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
from tn.english.normalizer import Normalizer as EnNormalizer
|
||||
print("failed to import ttsfrd, use wetext instead")
|
||||
from wetext import Normalizer as ZhNormalizer
|
||||
from wetext import Normalizer as EnNormalizer
|
||||
use_ttsfrd = False
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
||||
@@ -68,7 +68,7 @@ class CosyVoiceFrontEnd:
|
||||
'failed to initialize ttsfrd resource'
|
||||
self.frd.set_lang_type('pinyinvg')
|
||||
else:
|
||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
|
||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
|
||||
self.en_tn_model = EnNormalizer()
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -21,7 +22,8 @@ from torch.nn import functional as F
|
||||
from contextlib import nullcontext
|
||||
import uuid
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
|
||||
|
||||
class CosyVoiceModel:
|
||||
@@ -80,30 +82,28 @@ class CosyVoiceModel:
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model):
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
if os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
|
||||
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16):
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
|
||||
if isinstance(text, Generator):
|
||||
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
|
||||
assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
|
||||
for i in self.llm.inference_bistream(text=text,
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
@@ -118,7 +118,8 @@ class CosyVoiceModel:
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device)):
|
||||
embedding=llm_embedding.to(self.device),
|
||||
uuid=uuid):
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
@@ -231,7 +232,9 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice2Model(CosyVoiceModel):
|
||||
@@ -240,20 +243,17 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False,
|
||||
use_flow_cache: bool = False):
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
self.use_flow_cache = use_flow_cache
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.half()
|
||||
# stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
|
||||
# NOTE must matching training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
|
||||
# hift cache
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
@@ -265,55 +265,35 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
|
||||
def init_flow_cache(self):
|
||||
encoder_cache = {'offset': 0,
|
||||
'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
|
||||
'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
|
||||
'upsample_offset': 0,
|
||||
'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
|
||||
'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
|
||||
decoder_cache = {'offset': 0,
|
||||
'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
|
||||
'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
||||
'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
|
||||
'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
||||
'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
|
||||
'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
||||
'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
|
||||
if self.fp16 is True:
|
||||
for cache in [encoder_cache, decoder_cache]:
|
||||
for k, v in cache.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
cache[k] = v.half()
|
||||
cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
|
||||
return cache
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
|
||||
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
|
||||
max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
|
||||
input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
|
||||
assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
def load_vllm(self, model_dir):
|
||||
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
engine_args = EngineArgs(model=model_dir,
|
||||
skip_tokenizer_init=True,
|
||||
enable_prompt_embeds=True,
|
||||
gpu_memory_utilization=0.2)
|
||||
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
|
||||
self.llm.lock = threading.Lock()
|
||||
del self.llm.llm.model.model.layers
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
cache=self.flow_cache_dict[uuid],
|
||||
finalize=finalize)
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
@@ -348,34 +328,30 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.flow_cache_dict[this_uuid] = self.init_flow_cache()
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
|
||||
# NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
|
||||
flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
|
||||
prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
|
||||
token_offset = 0
|
||||
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
|
||||
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=this_uuid,
|
||||
stream=stream,
|
||||
finalize=False)
|
||||
# NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
|
||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
|
||||
prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
|
||||
token_offset += this_token_hop_len
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
@@ -384,18 +360,19 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=0,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
@@ -404,5 +381,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
@@ -14,14 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import json
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
||||
from cosyvoice.utils.file_utils import read_lists
|
||||
|
||||
|
||||
class Processor(IterableDataset):
|
||||
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
|
||||
data_pipeline,
|
||||
mode='train',
|
||||
gan=False,
|
||||
dpo=False,
|
||||
shuffle=True,
|
||||
partition=True,
|
||||
tts_file='',
|
||||
prompt_utt2data=''):
|
||||
partition=True):
|
||||
""" Construct dataset from arguments
|
||||
|
||||
We have two shuffle stage in the Dataset. The first is global
|
||||
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
|
||||
tokenizer (BaseTokenizer): tokenizer to tokenize
|
||||
partition(bool): whether to do data partition in terms of rank
|
||||
"""
|
||||
assert mode in ['train', 'inference']
|
||||
lists = read_lists(data_list_file)
|
||||
if mode == 'inference':
|
||||
with open(tts_file) as f:
|
||||
tts_data = json.load(f)
|
||||
utt2lists = read_json_lists(prompt_utt2data)
|
||||
# filter unnecessary file in inference mode
|
||||
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
|
||||
dataset = DataList(lists,
|
||||
shuffle=shuffle,
|
||||
partition=partition)
|
||||
if mode == 'inference':
|
||||
# map partial arg to parquet_opener func in inference mode
|
||||
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
||||
if gan is True:
|
||||
# map partial arg to padding func in gan mode
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
|
||||
# map partial arg to padding func
|
||||
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
|
||||
for func in data_pipeline:
|
||||
dataset = Processor(dataset, func, mode=mode)
|
||||
return dataset
|
||||
|
||||
@@ -43,8 +43,6 @@ def parquet_opener(data, mode='train', tts_data={}):
|
||||
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
||||
df = df.to_pandas()
|
||||
for i in range(len(df)):
|
||||
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
||||
continue
|
||||
sample.update(dict(df.loc[i]))
|
||||
if mode == 'train':
|
||||
# NOTE do not return sample directly, must initialize a new dict
|
||||
@@ -100,6 +98,8 @@ def filter(data,
|
||||
continue
|
||||
if len(sample['speech_token']) == 0:
|
||||
continue
|
||||
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
continue
|
||||
if num_frames != 0:
|
||||
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
||||
continue
|
||||
@@ -242,8 +242,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
||||
for sample in data:
|
||||
assert 'text' in sample
|
||||
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
||||
if mode == 'inference':
|
||||
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
||||
yield sample
|
||||
|
||||
|
||||
@@ -351,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
||||
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
||||
""" Wrapper for static/dynamic batch
|
||||
"""
|
||||
if mode == 'inference':
|
||||
return static_batch(data, 1)
|
||||
if batch_type == 'static':
|
||||
return static_batch(data, batch_size)
|
||||
elif batch_type == 'dynamic':
|
||||
return dynamic_batch(data, max_frames_in_batch)
|
||||
else:
|
||||
if batch_type == 'static':
|
||||
return static_batch(data, batch_size)
|
||||
elif batch_type == 'dynamic':
|
||||
return dynamic_batch(data, max_frames_in_batch)
|
||||
else:
|
||||
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
||||
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
||||
|
||||
|
||||
def padding(data, use_spk_embedding, mode='train', gan=False):
|
||||
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
""" Padding the data into training data
|
||||
|
||||
Args:
|
||||
@@ -424,16 +419,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
|
||||
# only gan train needs speech, delete it to save memory
|
||||
del batch["speech"]
|
||||
del batch["speech_len"]
|
||||
if mode == 'inference':
|
||||
tts_text = [sample[i]['tts_text'] for i in order]
|
||||
tts_index = [sample[i]['tts_index'] for i in order]
|
||||
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
||||
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
||||
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
||||
batch.update({'tts_text': tts_text,
|
||||
'tts_index': tts_index,
|
||||
'tts_text_token': tts_text_token,
|
||||
'tts_text_token_len': tts_text_token_len})
|
||||
if dpo is True:
|
||||
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
|
||||
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||
reject_speech_token = pad_sequence(reject_speech_token,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
batch['reject_speech_token'] = reject_speech_token
|
||||
batch['reject_speech_token_len'] = reject_speech_token_len
|
||||
if use_spk_embedding is True:
|
||||
batch["embedding"] = batch["spk_embedding"]
|
||||
else:
|
||||
|
||||
@@ -11,16 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import pack, rearrange, repeat
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
|
||||
from cosyvoice.utils.common import mask_to_bias
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
||||
from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
|
||||
from matcha.models.components.transformer import BasicTransformerBlock
|
||||
|
||||
|
||||
class Transpose(torch.nn.Module):
|
||||
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.transpose(x, self.dim0, self.dim1)
|
||||
return x
|
||||
|
||||
@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
|
||||
assert stride == 1
|
||||
self.causal_padding = kernel_size - 1
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if cache.size(2) == 0:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
else:
|
||||
assert cache.size(2) == self.causal_padding
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
cache = x[:, :, -self.causal_padding:]
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
return x, cache
|
||||
return x
|
||||
|
||||
|
||||
class CausalBlock1D(Block1D):
|
||||
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
output, cache = self.block[0](x * mask, cache)
|
||||
for i in range(1, len(self.block)):
|
||||
output = self.block[i](output)
|
||||
return output * mask, cache
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
class CausalResnetBlock1D(ResnetBlock1D):
|
||||
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
|
||||
self.block1 = CausalBlock1D(dim, dim_out)
|
||||
self.block2 = CausalBlock1D(dim_out, dim_out)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
|
||||
block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
h, block1_cache = self.block1(x, mask, block1_cache)
|
||||
h += self.mlp(time_emb).unsqueeze(-1)
|
||||
h, block2_cache = self.block2(h, mask, block2_cache)
|
||||
output = h + self.res_conv(x * mask)
|
||||
return output, block1_cache, block2_cache
|
||||
|
||||
|
||||
class CausalAttnProcessor2_0(AttnProcessor2_0):
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(CausalAttnProcessor2_0, self).__init__()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.Tensor]:
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
|
||||
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
||||
deprecate("scale", "1.0.0", deprecation_message)
|
||||
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key_cache = attn.to_k(encoder_hidden_states)
|
||||
value_cache = attn.to_v(encoder_hidden_states)
|
||||
# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
|
||||
if cache.size(0) != 0:
|
||||
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
||||
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
||||
else:
|
||||
key, value = key_cache, value_cache
|
||||
cache = torch.stack([key_cache, value_cache], dim=3)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states, cache
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CausalAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
cross_attention_norm_num_groups: int = 32,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional["AttnProcessor2_0"] = None,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
|
||||
cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
|
||||
spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
|
||||
_from_deprecated_attn_block, processor, out_dim)
|
||||
processor = CausalAttnProcessor2_0()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
**cross_attention_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
The forward method of the `Attention` class.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
The hidden states of the query.
|
||||
encoder_hidden_states (`torch.Tensor`, *optional*):
|
||||
The hidden states of the encoder.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
The attention mask to use. If `None`, no mask is applied.
|
||||
**cross_attention_kwargs:
|
||||
Additional keyword arguments to pass along to the cross attention.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The output of the attention layer.
|
||||
"""
|
||||
# The `Attention` class can call different attention processors / attention functions
|
||||
# here we simply pass along all tensors to the selected processor class
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
||||
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache=cache,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class CausalBasicTransformerBlock(BasicTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
|
||||
cross_attention_dim, activation_fn, num_embeds_ada_norm,
|
||||
attention_bias, only_cross_attention, double_self_attention,
|
||||
upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
|
||||
self.attn1 = CausalAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 1. Self-Attention
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
attn_output, cache = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
|
||||
cache=cache,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
|
||||
{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states, cache
|
||||
|
||||
|
||||
class ConditionalDecoder(nn.Module):
|
||||
def __init__(
|
||||
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CausalBasicTransformerBlock(
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CausalBasicTransformerBlock(
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CausalBasicTransformerBlock(
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
@@ -724,7 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
@@ -740,36 +434,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
masks = [mask]
|
||||
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||
mask_down = masks[-1]
|
||||
x, _, _ = resnet(x, mask_down, t)
|
||||
x = resnet(x, mask_down, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x, _ = transformer_block(
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
hiddens.append(x) # Save hidden states for skip connections
|
||||
x, _ = downsample(x * mask_down)
|
||||
x = downsample(x * mask_down)
|
||||
masks.append(mask_down[:, :, ::2])
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
|
||||
for resnet, transformer_blocks in self.mid_blocks:
|
||||
x, _, _ = resnet(x, mask_mid, t)
|
||||
x = resnet(x, mask_mid, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x, _ = transformer_block(
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
@@ -780,124 +474,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
||||
mask_up = masks.pop()
|
||||
skip = hiddens.pop()
|
||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||
x, _, _ = resnet(x, mask_up, t)
|
||||
x = resnet(x, mask_up, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x, _ = transformer_block(
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
x, _ = upsample(x * mask_up)
|
||||
x, _ = self.final_block(x, mask_up)
|
||||
x = upsample(x * mask_up)
|
||||
x = self.final_block(x, mask_up)
|
||||
output = self.final_proj(x * mask_up)
|
||||
return output * mask
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
|
||||
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
||||
mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
||||
up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
||||
final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||
mask (_type_): shape (batch_size, 1, time)
|
||||
t (_type_): shape (batch_size)
|
||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
ValueError: _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
||||
x = pack([x, spks], "b * t")[0]
|
||||
if cond is not None:
|
||||
x = pack([x, cond], "b * t")[0]
|
||||
|
||||
hiddens = []
|
||||
masks = [mask]
|
||||
|
||||
down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
|
||||
for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
|
||||
mask_down = masks[-1]
|
||||
x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
|
||||
resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for i, transformer_block in enumerate(transformer_blocks):
|
||||
x, down_blocks_kv_cache_new[index, i] = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
cache=down_blocks_kv_cache[index, i],
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
hiddens.append(x) # Save hidden states for skip connections
|
||||
x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
|
||||
masks.append(mask_down[:, :, ::2])
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
|
||||
for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
|
||||
x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
|
||||
resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for i, transformer_block in enumerate(transformer_blocks):
|
||||
x, mid_blocks_kv_cache_new[index, i] = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
cache=mid_blocks_kv_cache[index, i]
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
|
||||
for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
|
||||
mask_up = masks.pop()
|
||||
skip = hiddens.pop()
|
||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||
x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
|
||||
resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for i, transformer_block in enumerate(transformer_blocks):
|
||||
x, up_blocks_kv_cache_new[index, i] = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
cache=up_blocks_kv_cache[index, i]
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
|
||||
x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
|
||||
output = self.final_proj(x * mask_up)
|
||||
return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
|
||||
up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
|
||||
|
||||
@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# NOTE this is unnecessary, feat/h already same shape
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
@@ -243,7 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
cache,
|
||||
streaming,
|
||||
finalize):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
@@ -257,16 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
|
||||
# text encode
|
||||
if finalize is True:
|
||||
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
|
||||
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||
else:
|
||||
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
||||
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
|
||||
cache['encoder_cache']['offset'] = encoder_cache[0]
|
||||
cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
|
||||
cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
|
||||
cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
|
||||
cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
|
||||
cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
|
||||
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
@@ -276,14 +268,14 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, cache['decoder_cache'] = self.decoder(
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
cache=cache['decoder_cache']
|
||||
streaming=streaming
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), cache
|
||||
return feat.float(), None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,10 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
from cosyvoice.utils.common import set_all_random_seed
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
@@ -31,7 +32,6 @@ class ConditionalCFM(BASECFM):
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||
@@ -68,7 +68,7 @@ class ConditionalCFM(BASECFM):
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
@@ -109,7 +109,8 @@ class ConditionalCFM(BASECFM):
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in
|
||||
cond_in,
|
||||
streaming
|
||||
)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
@@ -121,25 +122,33 @@ class ConditionalCFM(BASECFM):
|
||||
|
||||
return sol[-1].float()
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator(x, mask, mu, t, spks, cond)
|
||||
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
torch.cuda.current_stream().synchronize()
|
||||
with stream:
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]) is True
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator, stream)
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
@@ -187,10 +196,11 @@ class ConditionalCFM(BASECFM):
|
||||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
set_all_random_seed(0)
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
@@ -209,131 +219,9 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
offset = cache.pop('offset')
|
||||
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
|
||||
z = z[:, :, offset:]
|
||||
offset += mu.size(2)
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
|
||||
cache['offset'] = offset
|
||||
return mel, cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
cache_step = {k: v[step - 1] for k, v in cache.items()}
|
||||
dphi_dt, cache_step = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in,
|
||||
cache_step
|
||||
)
|
||||
# NOTE if smaller than flow_cache_size, means last chunk, no need to cache
|
||||
if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
|
||||
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
|
||||
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
|
||||
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
|
||||
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
|
||||
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
|
||||
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
|
||||
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
return sol[-1].float(), cache
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
|
||||
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
|
||||
self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
|
||||
self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
|
||||
# run trt engine
|
||||
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
|
||||
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
||||
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
|
||||
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
|
||||
x.data_ptr(),
|
||||
cache['down_blocks_conv_cache'].data_ptr(),
|
||||
down_blocks_kv_cache_out.data_ptr(),
|
||||
cache['mid_blocks_conv_cache'].data_ptr(),
|
||||
mid_blocks_kv_cache_out.data_ptr(),
|
||||
cache['up_blocks_conv_cache'].data_ptr(),
|
||||
up_blocks_kv_cache_out.data_ptr(),
|
||||
cache['final_blocks_conv_cache'].data_ptr()]) is True
|
||||
cache = (cache['down_blocks_conv_cache'],
|
||||
down_blocks_kv_cache_out,
|
||||
cache['mid_blocks_conv_cache'],
|
||||
mid_blocks_kv_cache_out,
|
||||
cache['up_blocks_conv_cache'],
|
||||
up_blocks_kv_cache_out,
|
||||
cache['final_blocks_conv_cache'])
|
||||
return x, cache
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
||||
|
||||
@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
|
||||
return sine_merge, noise, uv
|
||||
|
||||
|
||||
class SineGen2(torch.nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0,
|
||||
flag_for_pulse=False):
|
||||
super(SineGen2, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.dim = self.harmonic_num + 1
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
self.flag_for_pulse = flag_for_pulse
|
||||
self.upsample_scale = upsample_scale
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||
return uv
|
||||
|
||||
def _f02sine(self, f0_values):
|
||||
""" f0_values: (batchsize, length, dim)
|
||||
where dim indicates fundamental tone and overtones
|
||||
"""
|
||||
# convert to F0 in rad. The interger part n can be ignored
|
||||
# because 2 * np.pi * n doesn't affect phase
|
||||
rad_values = (f0_values / self.sampling_rate) % 1
|
||||
|
||||
# initial phase noise (no noise for fundamental component)
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
|
||||
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||
if not self.flag_for_pulse:
|
||||
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||
scale_factor=1 / self.upsample_scale,
|
||||
mode="linear").transpose(1, 2)
|
||||
|
||||
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||
sines = torch.sin(phase)
|
||||
else:
|
||||
# If necessary, make sure that the first time step of every
|
||||
# voiced segments is sin(pi) or cos(0)
|
||||
# This is used for pulse-train generation
|
||||
|
||||
# identify the last time step in unvoiced segments
|
||||
uv = self._f02uv(f0_values)
|
||||
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||
uv_1[:, -1, :] = 1
|
||||
u_loc = (uv < 1) * (uv_1 > 0)
|
||||
|
||||
# get the instantanouse phase
|
||||
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||
# different batch needs to be processed differently
|
||||
for idx in range(f0_values.shape[0]):
|
||||
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||
# stores the accumulation of i.phase within
|
||||
# each voiced segments
|
||||
tmp_cumsum[idx, :, :] = 0
|
||||
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||
|
||||
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||
# within the previous voiced segment.
|
||||
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||
|
||||
# get the sines
|
||||
sines = torch.cos(i_phase * 2 * np.pi)
|
||||
return sines
|
||||
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
# fundamental component
|
||||
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||
|
||||
# generate sine waveforms
|
||||
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||
|
||||
# generate uv signal
|
||||
uv = self._f02uv(f0)
|
||||
|
||||
# noise: for unvoiced should be similar to sine_amp
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# . for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF2(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0):
|
||||
super(SourceModuleHnNSF2, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
|
||||
sine_amp, add_noise_std, voiced_threshod)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
"""
|
||||
# source for harmonic branch
|
||||
with torch.no_grad():
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
|
||||
|
||||
|
||||
class HiFTGenerator(nn.Module):
|
||||
"""
|
||||
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
|
||||
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
|
||||
self.m_source = this_SourceModuleHnNSF(
|
||||
sampling_rate=sampling_rate,
|
||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||
harmonic_num=nb_harmonics,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,7 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import queue
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, List, Generator
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -170,6 +174,7 @@ class TransformerLM(torch.nn.Module):
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
@@ -270,7 +275,6 @@ class Qwen2LM(TransformerLM):
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
@@ -293,6 +297,10 @@ class Qwen2LM(TransformerLM):
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(3)]
|
||||
self.vllm_output_queue = {}
|
||||
|
||||
def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
|
||||
lm_target, lm_input = [], []
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
@@ -369,6 +377,53 @@ class Qwen2LM(TransformerLM):
|
||||
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
|
||||
return {'loss': loss, 'acc': acc}
|
||||
|
||||
def forward_dpo(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
reject_speech_token = batch['reject_speech_token'].to(device)
|
||||
reject_speech_token_len = batch['reject_speech_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
|
||||
speech_token_combined = speech_token + reject_speech_token
|
||||
speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
|
||||
speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
|
||||
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
|
||||
speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
chosen_logits = logits[:text_token.shape[0]]
|
||||
rejected_logits = logits[text_token.shape[0]:]
|
||||
chosen_lm_target = lm_target[:text_token.shape[0]]
|
||||
rejected_lm_target = lm_target[text_token.shape[0]:]
|
||||
loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
|
||||
acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
|
||||
|
||||
# 5. calculate dpo logits
|
||||
chosen_lm_mask = chosen_lm_target == IGNORE_ID
|
||||
rejected_lm_mask = rejected_lm_target == IGNORE_ID
|
||||
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
|
||||
rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
|
||||
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
@@ -382,6 +437,7 @@ class Qwen2LM(TransformerLM):
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
@@ -402,22 +458,57 @@ class Qwen2LM(TransformerLM):
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
if top_ids > self.speech_token_size:
|
||||
continue
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
|
||||
if hasattr(self, 'vllm'):
|
||||
from vllm import SamplingParams, RequestOutput
|
||||
sampling_params = SamplingParams(top_k=sampling,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
min_tokens=min_len,
|
||||
max_tokens=max_len)
|
||||
with self.lock:
|
||||
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
|
||||
self.vllm_output_queue[uuid] = queue.Queue()
|
||||
out_tokens = []
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.vllm_output_queue[uuid].empty() is True:
|
||||
request_outputs: List[RequestOutput] = self.vllm.step()
|
||||
for request_output in request_outputs:
|
||||
top_ids = list(request_output.outputs[0].token_ids)[-1]
|
||||
self.vllm_output_queue[request_output.request_id].put(top_ids)
|
||||
if self.vllm_output_queue[uuid].empty() is False:
|
||||
top_ids = self.vllm_output_queue[uuid].get()
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
if len(out_tokens) == max_len:
|
||||
break
|
||||
time.sleep(0.001)
|
||||
with self.lock:
|
||||
self.vllm_output_queue.pop(uuid)
|
||||
else:
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
if top_ids > self.speech_token_size:
|
||||
continue
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_bistream(
|
||||
|
||||
@@ -56,16 +56,11 @@ class Upsample1D(nn.Module):
|
||||
# In this mode, first repeat interpolate, than conv with stride=1
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
||||
if conv_cache.size(2) == 0:
|
||||
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
||||
else:
|
||||
assert conv_cache.size(2) == self.stride * 2
|
||||
outputs = torch.concat([conv_cache, outputs], dim=2)
|
||||
conv_cache_new = outputs[:, :, -self.stride * 2:]
|
||||
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
||||
outputs = self.conv(outputs)
|
||||
return outputs, input_lengths * self.stride, conv_cache_new
|
||||
return outputs, input_lengths * self.stride
|
||||
|
||||
|
||||
class PreLookaheadLayer(nn.Module):
|
||||
@@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module):
|
||||
kernel_size=3, stride=1, padding=0,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
|
||||
"""
|
||||
inputs: (batch_size, seq_len, channels)
|
||||
"""
|
||||
@@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module):
|
||||
if context.size(2) == 0:
|
||||
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
||||
else:
|
||||
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||
assert context.size(2) == self.pre_lookahead_len
|
||||
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
||||
outputs = F.leaky_relu(self.conv1(outputs))
|
||||
# outputs
|
||||
if conv2_cache.size(2) == 0:
|
||||
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
||||
else:
|
||||
assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
|
||||
outputs = torch.concat([conv2_cache, outputs], dim=2)
|
||||
conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
|
||||
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = outputs.transpose(1, 2).contiguous()
|
||||
|
||||
# residual connection
|
||||
outputs = outputs + inputs
|
||||
return outputs, conv2_cache_new
|
||||
return outputs
|
||||
|
||||
|
||||
class UpsampleConformerEncoder(torch.nn.Module):
|
||||
@@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
xs_lens: torch.Tensor,
|
||||
context: torch.Tensor = torch.zeros(0, 0, 0),
|
||||
decoding_chunk_size: int = 0,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
streaming: bool = False,
|
||||
@@ -285,15 +277,19 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
xs, pos_emb, masks = self.embed(xs, masks)
|
||||
if context.size(1) != 0:
|
||||
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
||||
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
||||
# lookahead + conformer encoder
|
||||
xs, _ = self.pre_lookahead_layer(xs)
|
||||
xs = self.pre_lookahead_layer(xs, context=context)
|
||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
|
||||
# upsample + conformer encoder
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
xs, xs_lens, _ = self.up_layer(xs, xs_lens)
|
||||
xs, xs_lens = self.up_layer(xs, xs_lens)
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||
@@ -322,100 +318,3 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
||||
for layer in self.up_encoders:
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.export
|
||||
def forward_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
xs_lens: torch.Tensor,
|
||||
offset: int = 0,
|
||||
context: torch.Tensor = torch.zeros(0, 0, 0),
|
||||
pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
|
||||
encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
|
||||
upsample_offset: int = 0,
|
||||
upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
|
||||
upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs: padded input tensor (B, T, D)
|
||||
xs_lens: input length (B)
|
||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
Returns:
|
||||
encoder output tensor xs, and subsampled masks
|
||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||
masks: torch.Tensor batch padding mask after subsample
|
||||
(B, 1, T' ~= T/subsample_rate)
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
assert xs.size(0) == 1
|
||||
# tmp_masks is just for interface compatibility
|
||||
tmp_masks = torch.ones(1,
|
||||
xs.size(1),
|
||||
device=xs.device,
|
||||
dtype=torch.bool)
|
||||
tmp_masks = tmp_masks.unsqueeze(1)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
||||
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
||||
offset += xs.size(1)
|
||||
tmp_masks = torch.ones(1,
|
||||
context.size(1),
|
||||
device=context.device,
|
||||
dtype=torch.bool)
|
||||
tmp_masks = tmp_masks.unsqueeze(1)
|
||||
if context.size(1) != 0:
|
||||
context, _, _ = self.embed(context, tmp_masks, offset)
|
||||
|
||||
# lookahead + conformer encoder
|
||||
xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
|
||||
# NOTE in cache mode we do not need to call add_optional_chunk_mask
|
||||
chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
|
||||
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
||||
encoders_kv_cache_list = []
|
||||
for index, layer in enumerate(self.encoders):
|
||||
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
|
||||
encoders_kv_cache_list.append(encoders_kv_cache_new)
|
||||
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
|
||||
|
||||
# upsample
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
|
||||
# tmp_masks is just for interface compatibility
|
||||
tmp_masks = torch.ones(1,
|
||||
xs.size(1),
|
||||
device=xs.device,
|
||||
dtype=torch.bool)
|
||||
tmp_masks = tmp_masks.unsqueeze(1)
|
||||
xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
|
||||
upsample_offset += xs.size(1)
|
||||
|
||||
# conformer encoder
|
||||
chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
|
||||
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
||||
upsample_kv_cache_list = []
|
||||
for index, layer in enumerate(self.up_encoders):
|
||||
xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
|
||||
upsample_kv_cache_list.append(upsample_kv_cache_new)
|
||||
upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
|
||||
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
# Here we assume the mask is not changed in encoder layers, so just
|
||||
# return the masks before encoder layers, and the masks will be used
|
||||
# for cross attention with decoder later
|
||||
return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,6 +16,7 @@
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Unility functions for Transformer."""
|
||||
|
||||
import queue
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
@@ -164,3 +166,21 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
||||
mask = (1.0 - mask) * -1.0e+10
|
||||
return mask
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
|
||||
@@ -25,14 +25,16 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool = False):
|
||||
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
|
||||
self.gan = gan
|
||||
self.ref_model = ref_model
|
||||
self.dpo_loss = dpo_loss
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
@@ -44,6 +46,8 @@ class Executor:
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
@@ -65,7 +69,7 @@ class Executor:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import torchaudio
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
@@ -56,7 +59,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
@@ -83,3 +86,44 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
def export_cosyvoice2_vllm(model, model_path, device):
|
||||
if os.path.exists(model_path):
|
||||
return
|
||||
pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
vocab_size = model.speech_embedding.num_embeddings
|
||||
feature_size = model.speech_embedding.embedding_dim
|
||||
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
dtype = torch.bfloat16
|
||||
# lm_head
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
|
||||
with torch.no_grad():
|
||||
new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
|
||||
new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
|
||||
new_lm_head.weight[vocab_size:] = 0
|
||||
new_lm_head.bias[vocab_size:] = 0
|
||||
model.llm.model.lm_head = new_lm_head
|
||||
new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
|
||||
# embed_tokens
|
||||
embed_tokens = model.llm.model.model.embed_tokens
|
||||
with torch.no_grad():
|
||||
new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
|
||||
new_codec_embed.weight[vocab_size:] = 0
|
||||
model.llm.model.set_input_embeddings(new_codec_embed)
|
||||
model.llm.model.to(device)
|
||||
model.llm.model.to(dtype)
|
||||
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
|
||||
del model.llm.model.generation_config.eos_token_id
|
||||
del model.llm.model.config.bos_token_id
|
||||
del model.llm.model.config.eos_token_id
|
||||
model.llm.model.config.vocab_size = pad_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = False
|
||||
model.llm.model.config.use_bias = True
|
||||
model.llm.model.save_pretrained(model_path)
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||
model.llm.model.set_input_embeddings(embed_tokens)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
@@ -18,3 +19,39 @@ def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||
mel_g = transform(generated_speech)
|
||||
loss += F.l1_loss(mel_g, mel_r)
|
||||
return loss
|
||||
|
||||
|
||||
class DPOLoss(torch.nn.Module):
|
||||
"""
|
||||
DPO Loss
|
||||
"""
|
||||
|
||||
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.label_smoothing = label_smoothing
|
||||
self.ipo = ipo
|
||||
|
||||
def forward(
|
||||
self,
|
||||
policy_chosen_logps: torch.Tensor,
|
||||
policy_rejected_logps: torch.Tensor,
|
||||
reference_chosen_logps: torch.Tensor,
|
||||
reference_rejected_logps: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||
logits = pi_logratios - ref_logratios
|
||||
if self.ipo:
|
||||
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
|
||||
else:
|
||||
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
loss = losses.mean()
|
||||
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
||||
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
||||
|
||||
return loss, chosen_rewards, rejected_rewards
|
||||
|
||||
@@ -86,7 +86,7 @@ def subsequent_mask(
|
||||
return mask
|
||||
|
||||
|
||||
def subsequent_chunk_mask(
|
||||
def subsequent_chunk_mask_deprecated(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
@@ -124,6 +124,40 @@ def subsequent_chunk_mask(
|
||||
return ret
|
||||
|
||||
|
||||
def subsequent_chunk_mask(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
||||
pos_idx = torch.arange(size, device=device)
|
||||
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
||||
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
||||
return ret
|
||||
|
||||
|
||||
def add_optional_chunk_mask(xs: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
use_dynamic_chunk: bool,
|
||||
|
||||
@@ -50,10 +50,10 @@ def init_distributed(args):
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def init_dataset_and_dataloader(args, configs, gan):
|
||||
def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
||||
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
|
||||
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||
|
||||
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
||||
train_data_loader = DataLoader(train_dataset,
|
||||
@@ -235,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
|
||||
return False
|
||||
|
||||
|
||||
def batch_forward(model, batch, scaler, info_dict):
|
||||
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
|
||||
device = int(os.environ.get('LOCAL_RANK', 0))
|
||||
|
||||
dtype = info_dict["dtype"]
|
||||
@@ -253,6 +253,24 @@ def batch_forward(model, batch, scaler, info_dict):
|
||||
|
||||
with autocast:
|
||||
info_dict['loss_dict'] = model(batch, device)
|
||||
if ref_model is not None and dpo_loss is not None:
|
||||
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
||||
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|
||||
sft_loss = info_dict['loss_dict']['loss']
|
||||
with torch.no_grad():
|
||||
ref_loss_dict = ref_model(batch, device)
|
||||
reference_chosen_logps = ref_loss_dict["chosen_logps"]
|
||||
reference_rejected_logps = ref_loss_dict["rejected_logps"]
|
||||
preference_loss, chosen_reward, reject_reward = dpo_loss(
|
||||
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
|
||||
)
|
||||
dpo_acc = (chosen_reward > reject_reward).float().mean()
|
||||
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
|
||||
info_dict['loss_dict']["sft_loss"] = sft_loss
|
||||
info_dict['loss_dict']["dpo_loss"] = preference_loss
|
||||
info_dict['loss_dict']["dpo_acc"] = dpo_acc
|
||||
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
|
||||
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
|
||||
return info_dict
|
||||
|
||||
|
||||
|
||||
103
cosyvoice/vllm/cosyvoice2.py
Normal file
103
cosyvoice/vllm/cosyvoice2.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from vllm.model_executor.models.qwen2 import *
|
||||
|
||||
|
||||
class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -5,70 +5,40 @@ __set_seed3: !apply:torch.manual_seed [1986]
|
||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||
|
||||
# fixed params
|
||||
sample_rate: 22050
|
||||
text_encoder_input_size: 512
|
||||
llm_input_size: 1024
|
||||
llm_output_size: 1024
|
||||
sample_rate: 24000 # 16000 for llm, 24000 for cfm
|
||||
llm_input_size: 896
|
||||
llm_output_size: 896
|
||||
spk_embed_dim: 192
|
||||
qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN'
|
||||
|
||||
# model params
|
||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||
# for system/third_party class/function, we do not require this.
|
||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
||||
llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
|
||||
llm_input_size: !ref <llm_input_size>
|
||||
llm_output_size: !ref <llm_output_size>
|
||||
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
speech_token_size: 4096
|
||||
speech_token_size: 6561
|
||||
length_normalized_loss: True
|
||||
lsm_weight: 0
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
input_size: !ref <text_encoder_input_size>
|
||||
output_size: 1024
|
||||
attention_heads: 16
|
||||
linear_units: 4096
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
use_dynamic_chunk: False
|
||||
use_dynamic_left_chunk: False
|
||||
static_chunk_size: 1
|
||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
||||
input_size: !ref <llm_input_size>
|
||||
output_size: !ref <llm_output_size>
|
||||
attention_heads: 16
|
||||
linear_units: 4096
|
||||
num_blocks: 14
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: 'linear_legacy'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
static_chunk_size: 1
|
||||
dpo: True
|
||||
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
|
||||
pretrain_path: !ref <qwen_pretrain_path>
|
||||
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||
top_p: 0.8
|
||||
top_k: 25
|
||||
win_size: 10
|
||||
tau_r: 0.1
|
||||
|
||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
||||
input_size: 512
|
||||
output_size: 80
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
output_type: 'mel'
|
||||
vocab_size: 4096
|
||||
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
vocab_size: 6561
|
||||
input_frame_rate: 25
|
||||
only_mask_loss: True
|
||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
token_mel_ratio: 2
|
||||
pre_lookahead_len: 3
|
||||
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
|
||||
output_size: 512
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
@@ -83,10 +53,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
input_size: 512
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
||||
channels: 80
|
||||
sampling_ratios: [1, 1, 1, 1]
|
||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
||||
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
|
||||
in_channels: 240
|
||||
n_spks: 1
|
||||
spk_emb_dim: 80
|
||||
@@ -101,7 +68,8 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
||||
in_channels: 320
|
||||
out_channels: 80
|
||||
channels: [256, 256]
|
||||
causal: True
|
||||
channels: [256]
|
||||
dropout: 0.0
|
||||
attention_head_dim: 64
|
||||
n_blocks: 4
|
||||
@@ -117,15 +85,15 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
nsf_alpha: 0.1
|
||||
nsf_sigma: 0.003
|
||||
nsf_voiced_threshold: 10
|
||||
upsample_rates: [8, 8]
|
||||
upsample_kernel_sizes: [16, 16]
|
||||
upsample_rates: [8, 5, 3]
|
||||
upsample_kernel_sizes: [16, 11, 7]
|
||||
istft_params:
|
||||
n_fft: 16
|
||||
hop_len: 4
|
||||
resblock_kernel_sizes: [3, 7, 11]
|
||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
source_resblock_kernel_sizes: [7, 11]
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
||||
source_resblock_kernel_sizes: [7, 7, 11]
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
lrelu_slope: 0.1
|
||||
audio_limit: 0.99
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||
@@ -133,6 +101,25 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
in_channels: 80
|
||||
cond_channels: 512
|
||||
|
||||
# gan related module
|
||||
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: null
|
||||
center: False
|
||||
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||
generator: !ref <hift>
|
||||
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
|
||||
mel_spec_transform: [
|
||||
!ref <mel_spec_transform1>
|
||||
]
|
||||
|
||||
# processor functions
|
||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||
@@ -151,6 +138,8 @@ filter: !name:cosyvoice.dataset.processor.filter
|
||||
token_min_length: 1
|
||||
resample: !name:cosyvoice.dataset.processor.resample
|
||||
resample_rate: !ref <sample_rate>
|
||||
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||
truncate_length: 24576 # must be a multiplier of hop_size
|
||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
@@ -162,6 +151,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||
normalize: True
|
||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||
@@ -170,9 +162,10 @@ sort: !name:cosyvoice.dataset.processor.sort
|
||||
sort_size: 500 # sort_size should be less than shuffle_size
|
||||
batch: !name:cosyvoice.dataset.processor.batch
|
||||
batch_type: 'dynamic'
|
||||
max_frames_in_batch: 2000
|
||||
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
|
||||
padding: !name:cosyvoice.dataset.processor.padding
|
||||
use_spk_embedding: False # change to True during sft
|
||||
use_spk_embedding: True # change to True during sft
|
||||
dpo: True
|
||||
|
||||
# dataset processor pipeline
|
||||
data_pipeline: [
|
||||
@@ -187,17 +180,47 @@ data_pipeline: [
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
data_pipeline_gan: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <truncate>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <compute_f0>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
|
||||
# train conf
|
||||
# llm flow train conf
|
||||
train_conf:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001 # change to 1e-5 during sft
|
||||
lr: 0.00001 # change to 1e-5 during sft
|
||||
scheduler: warmuplr # change to constantlr during sft
|
||||
scheduler_conf:
|
||||
warmup_steps: 2500
|
||||
warmup_steps: 25000
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 2
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
|
||||
# gan train conf
|
||||
train_conf_gan:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler: constantlr
|
||||
optim_d: adam
|
||||
optim_conf_d:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler_d: constantlr
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
@@ -49,5 +49,7 @@ if __name__ == "__main__":
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
|
||||
50
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
50
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torchaudio
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main():
|
||||
cosyvoice = CosyVoice2(args.ref_model)
|
||||
|
||||
utt2wav, utt2text = {}, {}
|
||||
with open('{}/wav.scp'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2wav[l[0]] = l[1]
|
||||
with open('{}/text'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2text[l[0]] = ' '.join(l[1:])
|
||||
|
||||
os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
|
||||
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
||||
for utt, wav in tqdm(utt2wav.items()):
|
||||
prompt_speech_16k = load_wav(wav, 16000)
|
||||
if prompt_speech_16k.shape[1] >= 30 * 16000:
|
||||
continue
|
||||
speech_list = []
|
||||
for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
|
||||
speech_list.append(j['tts_speech'])
|
||||
negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
|
||||
torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
|
||||
f.write('{} {}\n'.format(utt, negative_wav))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_dir',
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
@@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
# inference
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
||||
for mode in sft zero_shot; do
|
||||
python cosyvoice/bin/inference.py --mode $mode \
|
||||
--gpu 0 \
|
||||
--config conf/cosyvoice.yaml \
|
||||
--prompt_data data/test-clean/parquet/data.list \
|
||||
--prompt_utt2data data/test-clean/parquet/utt2data.list \
|
||||
--tts_text `pwd`/tts_text.json \
|
||||
--llm_model $pretrained_model_dir/llm.pt \
|
||||
--flow_model $pretrained_model_dir/flow.pt \
|
||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
||||
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
@@ -15,7 +15,7 @@ token_mel_ratio: 2
|
||||
|
||||
# stream related params
|
||||
chunk_size: 25 # streaming inference chunk size, in token
|
||||
num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
|
||||
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
|
||||
|
||||
# model params
|
||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||
@@ -158,6 +158,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
token_mel_ratio: 2
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 480
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
|
||||
1
examples/libritts/cosyvoice2/path.sh
Symbolic link
1
examples/libritts/cosyvoice2/path.sh
Symbolic link
@@ -0,0 +1 @@
|
||||
../cosyvoice/path.sh
|
||||
@@ -51,25 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
# inference
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
||||
# TODO consider remove bin/inference.py, or use similar initilization method as in readme
|
||||
for mode in sft zero_shot; do
|
||||
python cosyvoice/bin/inference.py --mode $mode \
|
||||
--gpu 0 \
|
||||
--config conf/cosyvoice2.yaml \
|
||||
--prompt_data data/test-clean/parquet/data.list \
|
||||
--prompt_utt2data data/test-clean/parquet/utt2data.list \
|
||||
--tts_text `pwd`/tts_text.json \
|
||||
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||
--llm_model $pretrained_model_dir/llm.pt \
|
||||
--flow_model $pretrained_model_dir/flow.pt \
|
||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
||||
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
@@ -86,7 +67,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||
# NOTE will update llm/hift training later
|
||||
for model in llm flow; do
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
cosyvoice/bin/train.py \
|
||||
|
||||
123
examples/libritts/cosyvoice2/run_dpo.sh
Normal file
123
examples/libritts/cosyvoice2/run_dpo.sh
Normal file
@@ -0,0 +1,123 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
||||
. ./path.sh || exit 1;
|
||||
|
||||
stage=-1
|
||||
stop_stage=3
|
||||
|
||||
data_url=www.openslr.org/resources/60
|
||||
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
||||
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
|
||||
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
echo "Data Download"
|
||||
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x
|
||||
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Prepare negative samples using CosyVoice2-0.5B, this is also our reference model.
|
||||
Here we use CosyVoice2-0.5B generated audio as reject sample for simplicity, you can use metric like wer/similarity."
|
||||
for x in train-clean-100 train-clean-360 train-other-500; do
|
||||
mkdir -p data/${x}_reject
|
||||
python local/prepare_reject_sample.py --src_dir data/$x --des_dir data/${x}_reject --ref_model $pretrained_model_dir
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_embedding.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
|
||||
tools/extract_speech_token.py --dir data/$x \
|
||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||
mkdir -p data/$x/parquet
|
||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||
--num_processes 10 \
|
||||
--dpo \
|
||||
--src_dir data/$x \
|
||||
--des_dir data/$x/parquet
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
job_id=1986
|
||||
dist_backend="nccl"
|
||||
num_workers=2
|
||||
prefetch=100
|
||||
train_engine=torch_ddp
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
||||
if [ $train_engine == 'deepspeed' ]; then
|
||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||
fi
|
||||
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||
# NOTE only llm supports dpo
|
||||
for model in llm; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||
cosyvoice/bin/train.py \
|
||||
--train_engine $train_engine \
|
||||
--config conf/cosyvoice2.yaml \
|
||||
--train_data data/train.data.list \
|
||||
--cv_data data/dev.data.list \
|
||||
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||
--model $model \
|
||||
--checkpoint $pretrained_model_dir/$model.pt \
|
||||
--ref_model $pretrained_model_dir/llm.pt \
|
||||
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
|
||||
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
|
||||
--ddp.dist_backend $dist_backend \
|
||||
--num_workers ${num_workers} \
|
||||
--prefetch ${prefetch} \
|
||||
--pin_memory \
|
||||
--use_amp \
|
||||
--dpo \
|
||||
--deepspeed_config ./conf/ds_stage2.json \
|
||||
--deepspeed.save_states model+optimizer
|
||||
done
|
||||
fi
|
||||
|
||||
# average model
|
||||
average_num=5
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
for model in llm flow hifigan; do
|
||||
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||
python cosyvoice/bin/average_model.py \
|
||||
--dst_model $decode_checkpoint \
|
||||
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||
--num ${average_num} \
|
||||
--val_best
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||
fi
|
||||
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"1089_134686_000002_000000": [
|
||||
"hello, my name is Jack. What is your name?"
|
||||
]
|
||||
}
|
||||
1
examples/libritts/cosyvoice2/tts_text.json
Symbolic link
1
examples/libritts/cosyvoice2/tts_text.json
Symbolic link
@@ -0,0 +1 @@
|
||||
../cosyvoice/tts_text.json
|
||||
1
examples/magicdata-read/cosyvoice/conf
Symbolic link
1
examples/magicdata-read/cosyvoice/conf
Symbolic link
@@ -0,0 +1 @@
|
||||
../../libritts/cosyvoice/conf
|
||||
@@ -1,203 +0,0 @@
|
||||
# set random seed, so that you may reproduce your result.
|
||||
__set_seed1: !apply:random.seed [1986]
|
||||
__set_seed2: !apply:numpy.random.seed [1986]
|
||||
__set_seed3: !apply:torch.manual_seed [1986]
|
||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||
|
||||
# fixed params
|
||||
sample_rate: 22050
|
||||
text_encoder_input_size: 512
|
||||
llm_input_size: 1024
|
||||
llm_output_size: 1024
|
||||
spk_embed_dim: 192
|
||||
|
||||
# model params
|
||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||
# for system/third_party class/function, we do not require this.
|
||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
||||
llm_input_size: !ref <llm_input_size>
|
||||
llm_output_size: !ref <llm_output_size>
|
||||
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
speech_token_size: 4096
|
||||
length_normalized_loss: True
|
||||
lsm_weight: 0
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
input_size: !ref <text_encoder_input_size>
|
||||
output_size: 1024
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
use_dynamic_chunk: False
|
||||
use_dynamic_left_chunk: False
|
||||
static_chunk_size: 1
|
||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
||||
input_size: !ref <llm_input_size>
|
||||
output_size: !ref <llm_output_size>
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 7
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: 'linear_legacy'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
static_chunk_size: 1
|
||||
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||
top_p: 0.8
|
||||
top_k: 25
|
||||
win_size: 10
|
||||
tau_r: 0.1
|
||||
|
||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
input_size: 512
|
||||
output_size: 80
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
output_type: 'mel'
|
||||
vocab_size: 4096
|
||||
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
only_mask_loss: True
|
||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
output_size: 512
|
||||
attention_heads: 4
|
||||
linear_units: 1024
|
||||
num_blocks: 3
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
input_size: 512
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
||||
channels: 80
|
||||
sampling_ratios: [1, 1, 1, 1]
|
||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
||||
in_channels: 240
|
||||
n_spks: 1
|
||||
spk_emb_dim: 80
|
||||
cfm_params: !new:omegaconf.DictConfig
|
||||
content:
|
||||
sigma_min: 1e-06
|
||||
solver: 'euler'
|
||||
t_scheduler: 'cosine'
|
||||
training_cfg_rate: 0.2
|
||||
inference_cfg_rate: 0.7
|
||||
reg_loss_type: 'l1'
|
||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
||||
in_channels: 320
|
||||
out_channels: 80
|
||||
channels: [256, 256]
|
||||
dropout: 0.0
|
||||
attention_head_dim: 64
|
||||
n_blocks: 4
|
||||
num_mid_blocks: 8
|
||||
num_heads: 8
|
||||
act_fn: 'gelu'
|
||||
|
||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
in_channels: 80
|
||||
base_channels: 512
|
||||
nb_harmonics: 8
|
||||
sampling_rate: !ref <sample_rate>
|
||||
nsf_alpha: 0.1
|
||||
nsf_sigma: 0.003
|
||||
nsf_voiced_threshold: 10
|
||||
upsample_rates: [8, 8]
|
||||
upsample_kernel_sizes: [16, 16]
|
||||
istft_params:
|
||||
n_fft: 16
|
||||
hop_len: 4
|
||||
resblock_kernel_sizes: [3, 7, 11]
|
||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
source_resblock_kernel_sizes: [7, 11]
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
||||
lrelu_slope: 0.1
|
||||
audio_limit: 0.99
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||
num_class: 1
|
||||
in_channels: 80
|
||||
cond_channels: 512
|
||||
|
||||
# processor functions
|
||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||
multilingual: True
|
||||
num_languages: 100
|
||||
language: 'en'
|
||||
task: 'transcribe'
|
||||
allowed_special: 'all'
|
||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||
get_tokenizer: !ref <get_tokenizer>
|
||||
allowed_special: !ref <allowed_special>
|
||||
filter: !name:cosyvoice.dataset.processor.filter
|
||||
max_length: 40960
|
||||
min_length: 0
|
||||
token_max_length: 200
|
||||
token_min_length: 1
|
||||
resample: !name:cosyvoice.dataset.processor.resample
|
||||
resample_rate: !ref <sample_rate>
|
||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: 8000
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||
normalize: True
|
||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||
shuffle_size: 1000
|
||||
sort: !name:cosyvoice.dataset.processor.sort
|
||||
sort_size: 500 # sort_size should be less than shuffle_size
|
||||
batch: !name:cosyvoice.dataset.processor.batch
|
||||
batch_type: 'dynamic'
|
||||
max_frames_in_batch: 12000
|
||||
padding: !name:cosyvoice.dataset.processor.padding
|
||||
use_spk_embedding: False # change to True during sft
|
||||
|
||||
# dataset processor pipeline
|
||||
data_pipeline: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
|
||||
# train conf
|
||||
train_conf:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.002 # change to 0.001 if you want to train flow from scratch
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 2
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
@@ -1,42 +0,0 @@
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 100,
|
||||
"gradient_clipping": 5,
|
||||
"fp16": {
|
||||
"enabled": false,
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 256,
|
||||
"hysteresis": 2,
|
||||
"consecutive_hysteresis": false,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": false
|
||||
},
|
||||
"zero_force_ds_cpu_optimizer": false,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients" : true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 0.001,
|
||||
"weight_decay": 0.0001,
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
|
||||
1
examples/magicdata-read/cosyvoice/path.sh
Symbolic link
1
examples/magicdata-read/cosyvoice/path.sh
Symbolic link
@@ -0,0 +1 @@
|
||||
../../libritts/cosyvoice/path.sh
|
||||
@@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
# inference
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
||||
for mode in sft zero_shot; do
|
||||
python cosyvoice/bin/inference.py --mode $mode \
|
||||
--gpu 0 \
|
||||
--config conf/cosyvoice.yaml \
|
||||
--prompt_data data/test/parquet/data.list \
|
||||
--prompt_utt2data data/test/parquet/utt2data.list \
|
||||
--tts_text `pwd`/tts_text.json \
|
||||
--llm_model $pretrained_model_dir/llm.pt \
|
||||
--flow_model $pretrained_model_dir/flow.pt \
|
||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
||||
--result_dir `pwd`/exp/cosyvoice/test/$mode
|
||||
done
|
||||
fi
|
||||
|
||||
# train llm
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
@@ -83,7 +66,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
fi
|
||||
cp data/train/parquet/data.list data/train.data.list
|
||||
cp data/dev/parquet/data.list data/dev.data.list
|
||||
for model in llm flow; do
|
||||
for model in llm flow hifigan; do
|
||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||
cosyvoice/bin/train.py \
|
||||
@@ -99,11 +82,26 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
--num_workers ${num_workers} \
|
||||
--prefetch ${prefetch} \
|
||||
--pin_memory \
|
||||
--use_amp \
|
||||
--deepspeed_config ./conf/ds_stage2.json \
|
||||
--deepspeed.save_states model+optimizer
|
||||
done
|
||||
fi
|
||||
|
||||
# average model
|
||||
average_num=5
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
for model in llm flow hifigan; do
|
||||
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||
python cosyvoice/bin/average_model.py \
|
||||
--dst_model $decode_checkpoint \
|
||||
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||
--num ${average_num} \
|
||||
--val_best
|
||||
done
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
|
||||
conformer==0.3.2
|
||||
deepspeed==0.14.2; sys_platform == 'linux'
|
||||
deepspeed==0.15.1; sys_platform == 'linux'
|
||||
diffusers==0.29.0
|
||||
fastapi==0.115.6
|
||||
fastapi-cli==0.0.4
|
||||
@@ -36,5 +36,5 @@ torch==2.3.1
|
||||
torchaudio==2.3.1
|
||||
transformers==4.40.1
|
||||
uvicorn==0.30.0
|
||||
WeTextProcessing==1.0.3
|
||||
wetext==0.0.4
|
||||
wget==3.2
|
||||
|
||||
@@ -34,10 +34,10 @@ logging.basicConfig(level=logging.DEBUG,
|
||||
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
||||
def __init__(self, args):
|
||||
try:
|
||||
self.cosyvoice = CosyVoice(args.model_dir)
|
||||
self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
|
||||
except Exception:
|
||||
try:
|
||||
self.cosyvoice = CosyVoice2(args.model_dir)
|
||||
self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
|
||||
except Exception:
|
||||
raise TypeError('no valid model_type!')
|
||||
logging.info('grpc service initialized')
|
||||
|
||||
@@ -34,7 +34,9 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
|
||||
spk_list = [utt2spk[utt] for utt in utt_list]
|
||||
uttembedding_list = [utt2embedding[utt] for utt in utt_list]
|
||||
spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
|
||||
speech_token_list = [utt2speech_token[utt] for utt in utt_list]
|
||||
speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
|
||||
if args.dpo:
|
||||
reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
|
||||
|
||||
# 保存到parquet,utt2parquet_file,spk2parquet_file
|
||||
df = pd.DataFrame()
|
||||
@@ -46,6 +48,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
|
||||
df['utt_embedding'] = uttembedding_list
|
||||
df['spk_embedding'] = spkembedding_list
|
||||
df['speech_token'] = speech_token_list
|
||||
if args.dpo:
|
||||
df['reject_speech_token'] = reject_speech_token_list
|
||||
df.to_parquet(parquet_file)
|
||||
with open(utt2parquet_file, 'w') as f:
|
||||
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
|
||||
@@ -68,6 +72,10 @@ if __name__ == "__main__":
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--dpo',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use Direct Preference Optimization')
|
||||
args = parser.parse_args()
|
||||
|
||||
utt2wav, utt2text, utt2spk = {}, {}, {}
|
||||
@@ -86,6 +94,8 @@ if __name__ == "__main__":
|
||||
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
|
||||
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
|
||||
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
|
||||
if args.dpo:
|
||||
utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
|
||||
utts = list(utt2wav.keys())
|
||||
|
||||
# Using process pool to speedup
|
||||
|
||||
23
vllm_example.py
Normal file
23
vllm_example.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import sys
|
||||
sys.path.append('third_party/Matcha-TTS')
|
||||
from vllm import ModelRegistry
|
||||
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
|
||||
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
|
||||
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from cosyvoice.utils.common import set_all_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
|
||||
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
||||
for i in tqdm(range(100)):
|
||||
set_all_random_seed(i)
|
||||
for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user