add cosyvoice code

This commit is contained in:
lyuxiang.lx
2024-07-04 21:15:12 +08:00
parent 06984ac149
commit 076829ab84
64 changed files with 8428 additions and 18 deletions

76
CODE_OF_CONDUCT.md Normal file

@@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at mikelei@mobvoi.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

214
LICENSE

@@ -1,21 +1,201 @@
MIT License
Copyright (c) 2024 FunAudioLLM
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

146
README.md

@@ -1 +1,145 @@
# CosyVoice
For `CosyVoice`, visit the [CosyVoice repo](https://github.com/FunAudioLLM/CosyVoice) and [CosyVoice space](https://www.modelscope.cn/studios/iic/CosyVoice-300M).
For `SenseVoice`, visit the [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
## Install
**Clone and install**
- Clone the repo
``` sh
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# If you failed to clone the submodule due to network failures, please run the following command until it succeeds
cd CosyVoice
git submodule update --init --recursive
```
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:
``` sh
conda create -n cosyvoice python=3.8
conda activate cosyvoice
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# If you encounter sox compatibility issues
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel
```
**Model download**
We strongly recommend that you download our pretrained `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models and the `speech_kantts_ttsfrd` resource.
If you are an expert in this field and are only interested in training your own CosyVoice model from scratch, you can skip this step.
``` python
# Download models via the ModelScope SDK
from modelscope import snapshot_download
snapshot_download('speech_tts/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('speech_tts/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('speech_tts/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('speech_tts/speech_kantts_ttsfrd', local_dir='pretrained_models/speech_kantts_ttsfrd')
```
``` sh
# To download models via git, please make sure git lfs is installed
mkdir -p pretrained_models
git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M.git pretrained_models/CosyVoice-300M
git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
git clone https://www.modelscope.cn/speech_tts/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
git clone https://www.modelscope.cn/speech_tts/speech_kantts_ttsfrd.git pretrained_models/speech_kantts_ttsfrd
```
Unzip the `ttsfrd` resource and install the `ttsfrd` package
``` sh
cd pretrained_models/speech_kantts_ttsfrd/
unzip resource.zip -d .
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
```
**Basic Usage**
For zero_shot/cross_lingual inference, please use the `CosyVoice-300M` model.
For sft inference, please use the `CosyVoice-300M-SFT` model.
For instruct inference, please use the `CosyVoice-300M-Instruct` model.
First, add `third_party/AcademiCodec` and `third_party/Matcha-TTS` to your `PYTHONPATH`.
``` sh
export PYTHONPATH=third_party/AcademiCodec:third_party/Matcha-TTS
```
``` python
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio
cosyvoice = CosyVoice('speech_tts/CosyVoice-300M-SFT')
# sft usage
print(cosyvoice.list_avaliable_spks())
output = cosyvoice.inference_sft('你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?', '中文女')
torchaudio.save('sft.wav', output['tts_speech'], 22050)
cosyvoice = CosyVoice('speech_tts/CosyVoice-300M')
# zero_shot usage
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
# cross_lingual usage
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
cosyvoice = CosyVoice('speech_tts/CosyVoice-300M-Instruct')
# instruct usage
output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
torchaudio.save('instruct.wav', output['tts_speech'], 22050)
```
**Start web demo**
You can use our web demo page to get familiar with CosyVoice quickly.
We support sft/zero_shot/cross_lingual/instruct inference in the web demo.
Please see the demo website for details.
``` python
# change speech_tts/CosyVoice-300M-SFT for sft inference, or speech_tts/CosyVoice-300M-Instruct for instruct inference
python3 webui.py --port 50000 --model_dir speech_tts/CosyVoice-300M
```
**Advanced Usage**
For advanced users, we have provided training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
You can get familiar with CosyVoice by following this recipe, as sketched below.
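A minimal sketch of launching the recipe, assuming it follows the usual stage-based layout with stage variables edited inside `run.sh` itself:
``` sh
cd examples/libritts/cosyvoice
# adjust the stage/stop_stage variables in run.sh to pick data preparation,
# training or inference steps, then launch the recipe
bash run.sh
```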
**Build for deployment**
Optionally, if you want to use grpc for service deployment,
you can follow the steps below. Otherwise, you can skip this section.
``` sh
cd runtime/python
docker build -t cosyvoice:v1.0 .
# change speech_tts/CosyVoice-300M to speech_tts/CosyVoice-300M-Instruct if you want to use instruct inference
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python && python3 server.py --port 50000 --max_conc 4 --model_dir speech_tts/CosyVoice-300M && sleep infinity"
python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```
## Discussion & Communication
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
You can also scan the QR code to join our official Dingding chat group.
<img src="./asset/dingding.png" width="250px">
## Acknowledgements
1. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
2. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
3. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).

BIN
asset/dingding.png Normal file

Binary file not shown. Size: 94 KiB

0
cosyvoice/__init__.py Normal file

114
cosyvoice/bin/inference.py Normal file

@@ -0,0 +1,114 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
parser.add_argument('--result_dir', required=True, help='directory to save tts results')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for batch_idx, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text = batch["text"]
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_text = batch["tts_text"]
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
model_output = model.inference(**model_input)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
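A hedged example of invoking this script from the repository root, using only the arguments defined in `get_args()` above; all file paths are hypothetical placeholders for artifacts produced by the data-preparation recipe:
``` sh
python3 cosyvoice/bin/inference.py \
  --mode zero_shot \
  --gpu 0 \
  --config conf/cosyvoice.yaml \
  --prompt_data data/test/parquet/data.list \
  --prompt_utt2data data/test/parquet/utt2data.list \
  --tts_text data/test/tts_text.json \
  --llm_model pretrained_models/CosyVoice-300M/llm.pt \
  --flow_model pretrained_models/CosyVoice-300M/flow.pt \
  --hifigan_model pretrained_models/CosyVoice-300M/hift.pt \
  --result_dir exp/cosyvoice/test
```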

137
cosyvoice/bin/train.py Normal file

@@ -0,0 +1,137 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import torch
import torch.distributed as dist
import deepspeed
from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for parallel training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=30,
type=int,
help='timeout (in seconds) of cosyvoice_join. ' +
'30s for aishell & 300s for wenetspeech')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs)
# Do some sanity checks and save config to args.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
save_model(model, 'init', info_dict)
# Get executor
executor = Executor()
# Start training loop
for epoch in range(info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()
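Because `main()` is decorated with `@record`, the script is intended to be launched through the torchelastic entry point. A hedged, single-node sketch using only the arguments defined in `get_args()` above (config and data paths are hypothetical placeholders):
``` sh
torchrun --nnodes=1 --nproc_per_node=2 cosyvoice/bin/train.py \
  --train_engine torch_ddp \
  --model llm \
  --config conf/cosyvoice.yaml \
  --train_data data/train/parquet/data.list \
  --cv_data data/dev/parquet/data.list \
  --model_dir exp/cosyvoice/llm \
  --tensorboard_dir tensorboard/cosyvoice/llm \
  --num_workers 2 \
  --pin_memory
```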

83
cosyvoice/cli/cosyvoice.py Normal file

@@ -0,0 +1,83 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel
class CosyVoice:
def __init__(self, model_dir):
instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
configs = load_hyperpyyaml(f)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
instruct,
configs['allowed_special'])
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
del configs
def list_avaliable_spks(self):
spks = list(self.frontend.spk2info.keys())
return spks
def inference_sft(self, tts_text, spk_id):
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_sft(i, spk_id)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_cross_lingual(self, tts_text, prompt_speech_16k):
if self.frontend.instruct is True:
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
def inference_instruct(self, tts_text, spk_id, instruct_text):
if self.frontend.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
tts_speeches = []
for i in self.frontend.text_normalize(tts_text, split=True):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
model_output = self.model.inference(**model_input)
tts_speeches.append(model_output['tts_speech'])
return {'tts_speech': torch.concat(tts_speeches, dim=1)}

146
cosyvoice/cli/frontend.py Normal file

@@ -0,0 +1,146 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import inflect
import ttsfrd
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
class CosyVoiceFrontEnd:
def __init__(self,
get_tokenizer: Callable,
feat_extractor: Callable,
campplus_model: str,
speech_tokenizer_model: str,
spk2info: str = '',
instruct: bool = False,
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.feat_extractor = feat_extractor
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
self.instruct = instruct
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/speech_kantts_ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyin')
self.frd.enable_pinyin_mix(True)
self.frd.set_breakmodel_index(1)
def _extract_text_token(self, text):
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
def _extract_speech_token(self, speech):
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, speech):
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
return speech_feat, speech_feat_len
def text_normalize(self, text, split=True):
text = text.strip()
if contains_chinese(text):
text = self.frd.get_frd_extra_info(text, 'input').replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
text = text.replace(".", "")
text = text.replace(" - ", "")
text = remove_bracket(text)
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20,
comma_split=False)]
else:
text = spell_out_number(text, self.inflect_parser)
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20,
comma_split=False)]
if split is False:
return text
return texts
def frontend_sft(self, tts_text, spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
embedding = self.spk2info[spk_id]['embedding']
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
embedding = self._extract_spk_embedding(prompt_speech_16k)
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_instruct(self, tts_text, spk_id, instruct_text):
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input

59
cosyvoice/cli/model.py Normal file

@@ -0,0 +1,59 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class CosyVoiceModel:
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
self.llm.to(self.device).eval()
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
self.flow.to(self.device).eval()
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
self.hift.to(self.device).eval()
def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
tts_speech_token = self.llm.inference(text=text.to(self.device),
text_len=text_len.to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=prompt_text_len.to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
embedding=llm_embedding.to(self.device),
beam_size=1,
sampling=25,
max_token_text_ratio=30,
min_token_text_ratio=3)
tts_mel = self.flow.inference(token=tts_speech_token,
token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
prompt_token=flow_prompt_speech_token.to(self.device),
prompt_token_len=flow_prompt_speech_token_len.to(self.device),
prompt_feat=prompt_speech_feat.to(self.device),
prompt_feat_len=prompt_speech_feat_len.to(self.device),
embedding=flow_embedding.to(self.device))
tts_speech = self.hift.inference(mel=tts_mel).cpu()
return {'tts_speech': tts_speech}
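To make the defaults above concrete: a minimal sketch of an SFT-style call that relies on the zero-length prompt tensors, assuming `model` is a loaded `CosyVoiceModel` and that `text_token`, `text_token_len`, and a 192-dim `spk_embedding` were already produced by `CosyVoiceFrontEnd.frontend_sft`:
``` python
import torchaudio

# No prompt text/speech: the default zero-length tensors are used, and the same
# speaker embedding conditions both the LLM and the flow model.
output = model.inference(text=text_token,
                         text_len=text_token_len,
                         llm_embedding=spk_embedding,
                         flow_embedding=spk_embedding)
torchaudio.save('direct_sft.wav', output['tts_speech'], 22050)
```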

160
cosyvoice/dataset/dataset.py Normal file

@@ -0,0 +1,160 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import json
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists, read_json_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
# force datalist even
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
if len(data) < self.world_size:
data = data * math.ceil(self.world_size / len(data))
data = data[:self.world_size]
data = data[self.rank::self.world_size]
if len(data) < self.num_workers:
data = data * math.ceil(self.num_workers / len(data))
data = data[:self.num_workers]
data = data[self.worker_id::self.num_workers]
return data
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def Dataset(data_list_file,
data_pipeline,
mode='train',
shuffle=True,
partition=True,
tts_file='',
prompt_utt2data=''):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw file level. The second is global shuffle
at training samples level.
Args:
data_type(str): raw/shard
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
assert mode in ['train', 'inference']
lists = read_lists(data_list_file)
if mode == 'inference':
with open(tts_file) as f:
tts_data = json.load(f)
utt2lists = read_json_lists(prompt_utt2data)
# filter unnecessary file in inference mode
lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
if mode == 'inference':
# map partial arg tts_data in inference mode
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset
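To make the sampling logic concrete, a self-contained sketch of how `DistributedSampler.sample` splits an index list first across DDP ranks and then across dataloader workers (the numbers are purely illustrative, assuming `partition=True`):
``` python
import math
import random

def partition(indices, rank, world_size, worker_id, num_workers, epoch=0, shuffle=True):
    # Mirrors DistributedSampler.sample: deterministic per-epoch shuffle, pad the
    # list up to world_size if needed, slice by rank, then slice by worker id.
    data = list(indices)
    if shuffle:
        random.Random(epoch).shuffle(data)
    if len(data) < world_size:
        data = (data * math.ceil(world_size / len(data)))[:world_size]
    data = data[rank::world_size]
    if len(data) < num_workers:
        data = (data * math.ceil(num_workers / len(data)))[:num_workers]
    return data[worker_id::num_workers]

# 10 parquet shards, 2 ranks with 2 dataloader workers each:
# every (rank, worker) pair reads a disjoint subset of shards.
for rank in range(2):
    for worker in range(2):
        print(rank, worker, partition(range(10), rank, 2, worker, 2))
```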

366
cosyvoice/dataset/processor.py Normal file

@@ -0,0 +1,366 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
torchaudio.set_audio_backend('soundfile')
torchaudio.utils.sox_utils.set_buffer_size(16500)
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def parquet_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in range(len(df)):
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
continue
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterances longer than max_length (in 10ms frames)
min_length: drop utterances shorter than min_length (in 10ms frames)
token_max_length: drop utterances whose token length is greater than
token_max_length, especially when using char units for
English modeling
token_min_length: drop utterances whose token length is less than
token_min_length
min_output_input_ratio: minimal ratio of
token_length / feats_length (10ms)
max_output_input_ratio: maximum ratio of
token_length / feats_length (10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
del sample['audio_data']
# sample['speech'] is a torch.Tensor; we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['text_token']) / num_frames > max_output_input_ratio:
continue
yield sample
def resample(data, resample_rate=22050, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < resample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if mode == 'inference':
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_feat' in sample
assert isinstance(sample['speech_feat'], torch.Tensor)
new_sample_frames = sample['speech_feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
yield batch
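Editor's note: the processor stages above are plain generators that can be chained directly. Below is a minimal, self-contained sketch (not part of the commit) of how shuffle, sort, batch and padding compose; the fake_samples helper, its field names and the tensor sizes are illustrative assumptions only.

import random
import torch

def fake_samples(n=32):
    # hypothetical stand-in for the upstream sample stream
    for i in range(n):
        t = random.randint(50, 200)                        # number of mel frames
        yield {
            'utt': f'utt{i}',
            'speech_feat': torch.randn(t, 80),             # (T, n_mels)
            'speech_token': torch.randint(0, 4096, (t // 2,)).tolist(),
            'text': 'hello',
            'text_token': torch.randint(0, 100, (10,)).tolist(),
            'utt_embedding': torch.randn(192),
            'spk_embedding': torch.randn(192),
        }

pipeline = padding(batch(sort(shuffle(fake_samples(), shuffle_size=16), sort_size=8),
                         batch_type='dynamic', max_frames_in_batch=2000))
first = next(iter(pipeline))
print(first['speech_feat'].shape, first['speech_feat_len'])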

222
cosyvoice/flow/decoder.py Executable file
View File

@@ -0,0 +1,222 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
):
"""
This decoder requires an input with the same shape as the target. So, if your text content
is shorter or longer than the output, please re-sample it before feeding it to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for i in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
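Editor's note: a hedged shape-check sketch (not part of the commit). The decoder's in_channels must equal the number of channels packed onto the input in forward (x, mu, the broadcast speaker embedding and cond), e.g. 80 * 4 = 320 for 80-dim mel features. The channel sizes below are illustrative and assume the matcha-tts components imported above are installed.

import torch

est = ConditionalDecoder(in_channels=320, out_channels=80, channels=(256, 256),
                         n_blocks=1, num_mid_blocks=2, num_heads=4, act_fn="gelu")
B, T = 2, 64
x = torch.randn(B, 80, T)       # noisy mel
mu = torch.randn(B, 80, T)      # encoder output, same length as the target
spks = torch.randn(B, 80)       # projected speaker embedding, broadcast over time
cond = torch.randn(B, 80, T)    # prompt-mel conditions
mask = torch.ones(B, 1, T)
t = torch.rand(B)               # flow-matching time in [0, 1]
out = est(x, mask, mu, t, spks=spks, cond=cond)
print(out.shape)                # torch.Size([2, 80, 64])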

135
cosyvoice/flow/flow.py Normal file
View File

@@ -0,0 +1,135 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['utt_embedding'].to(device)
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
feat_len = (token_len / 50 * 22050 / 256).int()
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
if prompt_feat.shape[1] != 0:
for i, j in enumerate(prompt_feat_len):
conds[i, :j] = prompt_feat[i]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10
)
if prompt_feat.shape[1] != 0:
feat = feat[:, :, prompt_feat.shape[1]:]
return feat
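Editor's note: a short worked example (not part of the commit) of the hard-coded length conversion in inference above: speech tokens arrive at 50 tokens per second, while mel frames are produced at sampling_rate / hop_size = 22050 / 256 ≈ 86.1 frames per second.

token_len = 100                              # 100 tokens at 50 tokens/s -> 2 s of speech
feat_len = int(token_len / 50 * 22050 / 256)
print(feat_len)                              # 172 mel frames (2 s * 22050 / 256 ≈ 172.3)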

131
cosyvoice/flow/flow_matching.py Executable file
View File

@@ -0,0 +1,131 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed-step Euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
torch.zeros_like(cond)
)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
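Editor's note: a toy sanity check (not part of the commit) for the Euler solver above. With a constant estimator dphi/dt = mu, zero temperature and guidance disabled, integrating over the unit time span returns exactly mu. The ConstantField module is an illustrative stand-in for the real decoder and assumes omegaconf and the matcha BASECFM dependency are available.

import torch
from omegaconf import DictConfig

class ConstantField(torch.nn.Module):
    # hypothetical estimator: the predicted vector field is simply mu everywhere
    def forward(self, x, mask, mu, t, spks, cond):
        return mu

cfm_params = DictConfig({'sigma_min': 1e-6, 'solver': 'euler', 't_scheduler': 'cosine',
                         'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.0})
cfm = ConditionalCFM(in_channels=240, cfm_params=cfm_params, estimator=ConstantField())
mu = torch.randn(1, 80, 50)
mask = torch.ones(1, 1, 50)
mel = cfm(mu, mask, n_timesteps=10, temperature=0.0)   # z = 0, so x(1) = sum(dt) * mu = mu
print(torch.allclose(mel, mu, atol=1e-5))              # True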

View File

@@ -0,0 +1,49 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
from torch.nn import functional as F
from cosyvoice.utils.mask import make_pad_mask
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
out_channels: int = None,
groups: int = 1,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
def forward(self, x, ylens=None):
# x in (B, T, D)
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
out = self.model(x).transpose(1, 2).contiguous()
olens = ylens
return out * mask, olens
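Editor's note: illustrative usage (not part of the commit). The regulator nearest-interpolates the encoder output along time to the per-utterance mel length and refines it with the small conv stack; the sizes below are assumptions.

import torch

reg = InterpolateRegulator(channels=80, sampling_ratios=[1, 1, 1, 1])
x = torch.randn(2, 120, 80)          # (B, T_token, D) encoder output
ylens = torch.tensor([200, 150])     # target mel lengths per utterance
out, olens = reg(x, ylens)
print(out.shape, olens)              # torch.Size([2, 200, 80]) tensor([200, 150])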

View File

@@ -0,0 +1,55 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class ConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
weight_norm(
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
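Editor's note: illustrative usage (not part of the commit): the predictor maps a mel-spectrogram (B, n_mels, T) to a non-negative frame-level F0 track (B, T).

import torch

f0_pred = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 100)        # (B, n_mels, T)
f0 = f0_pred(mel)
print(f0.shape)                      # torch.Size([2, 100]), values are non-negative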

View File

@@ -0,0 +1,391 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HIFI-GAN"""
import typing as tp
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from cosyvoice.transformer.activation import Snake
from academicodec.utils import get_padding
from academicodec.utils import init_weights
"""hifigan based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan
,https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""
class ResBlock(torch.nn.Module):
"""Residual block module in HiFiGAN/BigVGAN."""
def __init__(
self,
channels: int = 512,
kernel_size: int = 3,
dilations: tp.List[int] = [1, 3, 5],
):
super(ResBlock, self).__init__()
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
for dilation in dilations:
self.convs1.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
)
)
)
self.convs2.append(
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
)
)
)
self.convs1.apply(init_weights)
self.convs2.apply(init_weights)
self.activations1 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs1))
])
self.activations2 = nn.ModuleList([
Snake(channels, alpha_logscale=False)
for _ in range(len(self.convs2))
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for idx in range(len(self.convs1)):
xt = self.activations1[idx](x)
xt = self.convs1[idx](xt)
xt = self.activations2[idx](xt)
xt = self.convs2[idx](xt)
x = xt + x
return x
def remove_weight_norm(self):
for idx in range(len(self.convs1)):
remove_weight_norm(self.convs1[idx])
remove_weight_norm(self.convs2[idx])
class SineGen(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
@torch.no_grad()
def forward(self, f0):
"""
:param f0: [B, 1, sample_len], Hz
:return: sine_waves [B, harmonic_num+1, sample_len], uv [B, 1, sample_len], noise [B, harmonic_num+1, sample_len]
"""
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
# generate sine waveforms
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
# generate uv signal
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threshold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
class HiFTGenerator(nn.Module):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: tp.List[int] = [8, 8],
upsample_kernel_sizes: tp.List[int] = [16, 16],
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
f0_predictor: torch.nn.Module = None,
):
super(HiFTGenerator, self).__init__()
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
Conv1d(in_channels, base_channels, 7, 1, padding=3)
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
)
else:
self.source_downs.append(
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.f0_predictor = f0_predictor
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, _, _ = self.m_source(f0)
return har_source.transpose(1, 2)
def _stft(self, x):
spec = torch.stft(
x,
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
return_complex=True)
spec = torch.view_as_real(spec) # [B, F, TT, 2]
return spec[..., 0], spec[..., 1]
def _istft(self, magnitude, phase):
magnitude = torch.clip(magnitude, max=1e2)
real = magnitude * torch.cos(phase)
img = magnitude * torch.sin(phase)
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
return inverse_transform
def forward(self, x: torch.Tensor) -> torch.Tensor:
f0 = self.f0_predictor(x)
s = self._f02source(f0)
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundant
x = self._istft(magnitude, phase)
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
# NOTE: self.m_source (SourceModuleHnNSF) contains no weight-normalized layers, so there is nothing to remove here.
for l in self.source_downs:
remove_weight_norm(l)
for l in self.source_resblocks:
l.remove_weight_norm()
@torch.inference_mode()
def inference(self, mel: torch.Tensor) -> torch.Tensor:
return self.forward(x=mel)
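Editor's note: an end-to-end vocoder sketch (not part of the commit). With the default hyper-parameters the overall upsampling factor is prod(upsample_rates) * istft hop_len = 8 * 8 * 4 = 256 samples per mel frame, matching a 256-sample mel hop at 22050 Hz. It assumes the repo's dependencies (academicodec, the Snake activation and the ConvRNNF0Predictor defined above) are importable.

import torch

hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor())
mel = torch.randn(1, 80, 100)        # ~1.16 s of audio at 22050 Hz with a 256-sample hop
wav = hift.inference(mel)
print(wav.shape)                     # torch.Size([1, 25600]) == 100 frames * 256 samples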

206
cosyvoice/llm/llm.py Normal file
View File

@@ -0,0 +1,206 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
class TransformerLM(torch.nn.Module):
def __init__(
self,
text_encoder_input_size: int,
llm_input_size: int,
llm_output_size: int,
text_token_size: int,
speech_token_size: int,
text_encoder: torch.nn.Module,
llm: torch.nn.Module,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
spk_embed_dim: int = 192,
):
super().__init__()
self.llm_input_size = llm_input_size
self.speech_token_size = speech_token_size
# 1. build text token inputs related modules
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
self.text_encoder = text_encoder
self.text_encoder_affine_layer = nn.Linear(
self.text_encoder.output_size(),
llm_input_size
)
# 2. build speech token language model related modules
self.sos_eos = 0
self.task_id = 1
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 1,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
def encode(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
):
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out = self.text_encoder_affine_layer(encoder_out)
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
return lm_input, lm_input_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
embedding = batch['utt_embedding'].to(device)
# 1. prepare llm_target
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
# 1. encode text_token
text_token = self.text_embedding(text_token)
text_token, text_token_len = self.encode(text_token, text_token_len)
# 2. embedding projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(1)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
# 6. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target)
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
def sampling_ids(
self,
weighted_scores: torch.Tensor,
sampling: Union[bool, int, float] = True,
beam_size: int = 1,
ignore_eos: bool = True,
):
while True:
prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
top_ids = prob.multinomial(beam_size, replacement=True)
top_ids = indices[top_ids]
if (not ignore_eos) or (self.speech_token_size not in top_ids):
break
return top_ids
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
beam_size: int = 1,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
) -> torch.Tensor:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.text_embedding(text)
# 1. encode text
text, text_len = self.encode(text, text_len)
# 2. encode embedding
if embedding.shape[0] != 0:
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(dim=1)
else:
embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
out_tokens = []
offset = 0
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
for i in range(max_len):
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
out_tokens.append(top_ids)
offset += lm_input.size(1)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
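Editor's note: a standalone sketch (not part of the commit) of the decoding rule in sampling_ids above: top-k (k = sampling) multinomial sampling over the softmax distribution, re-drawing whenever the EOS token (index speech_token_size) appears before min_len tokens have been emitted. The function name and the 4096-token vocabulary below are illustrative assumptions.

import torch

def sample_top_k(logits, k=25, eos_id=4096, ignore_eos=True):
    # mirrors TransformerLM.sampling_ids with beam_size=1
    while True:
        prob, indices = logits.softmax(dim=-1).topk(k)
        top_id = indices[prob.multinomial(1, replacement=True)]
        if (not ignore_eos) or (eos_id not in top_id):
            return top_id

logits = torch.randn(4097)           # scores over 4096 speech tokens + EOS
print(sample_top_k(logits).item())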

View File

View File

@@ -0,0 +1,84 @@
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Swish(torch.nn.Module):
"""Construct an Swish object."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Return Swish activation function."""
return x * torch.sigmoid(x)
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
'''
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = Snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
'''
super(Snake, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
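Editor's note: a quick numeric check (not part of the commit) that the module matches Snake(x) = x + (1/alpha) * sin^2(alpha * x) element-wise (up to the tiny no_div_by_zero term).

import torch

snake = Snake(in_features=4, alpha=1.0)
x = torch.randn(2, 4, 8)                               # (B, C, T)
expected = x + torch.sin(x) ** 2                       # alpha = 1
print(torch.allclose(snake(x), expected, atol=1e-6))   # True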

View File

@@ -0,0 +1,326 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
1.When applying cross attention between decoder and encoder,
the batch padding mask for input is in (#batch, 1, T) shape.
2.When applying self attention of encoder,
the mask is in (#batch, T, T) shape.
3.When applying self attention of decoder,
the mask is in (#batch, L, L) shape.
4.If the different position in decoder see different block
of the encoder, such as Mocha, the passed in mask could be
in (#batch, L, T) shape. But there is no such case in current
Wenet.
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, key_bias)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
return x
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will always be `True`
# and we will always do splitting and
# concatenation (this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
if matrix_ac.shape != matrix_bd.shape:
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
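Editor's note: an illustrative sketch (not part of the commit) of the streaming KV-cache contract described in the NOTEs above: the cache returned for one chunk is fed back on the next call, so the cached keys/values grow by time1 at every step (sizes below are assumptions).

import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
chunk1 = torch.randn(1, 6, 256)
out1, cache = mha(chunk1, chunk1, chunk1)               # cache: (1, 4, 6, 128)
chunk2 = torch.randn(1, 4, 256)
out2, cache = mha(chunk2, chunk2, chunk2, cache=cache)  # keys/values now cover 10 frames
print(out2.shape, cache.shape)   # torch.Size([1, 4, 256]) torch.Size([1, 4, 10, 128])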

View File

@@ -0,0 +1,145 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple
import torch
from torch import nn
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (bool): Whether to use causal convolution or not
"""
super().__init__()
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for non-causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) means fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It would be better to just return None when no cache is required;
# however, for JIT export we fake a zero-sized tensor here instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
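# Usage sketch (illustrative, not part of the upstream file): streaming a
# causal ConvolutionModule chunk by chunk while carrying the left-context
# cache, following the shapes documented in forward() above.
if __name__ == "__main__":
    conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
    conv.eval()  # BatchNorm1d then uses its (initial) running stats
    cache = torch.zeros((0, 0, 0))  # fake cache for the first chunk
    with torch.no_grad():
        for _ in range(3):
            chunk = torch.randn(1, 8, 256)  # (#batch, time, channels)
            out, cache = conv(chunk, cache=cache)
            # cache keeps the last lorder = kernel_size - 1 = 14 frames
            assert out.shape == (1, 8, 256) and cache.shape == (1, 256, 14)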

View File

@@ -0,0 +1,396 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional
import torch
import torch.utils.checkpoint as ckpt
import logging
from cosyvoice.transformer.decoder_layer import DecoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
COSYVOICE_EMB_CLASSES,
COSYVOICE_ATTENTION_CLASSES,
COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
class TransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the hidden units number of position-wise feedforward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
src_attention: if False, encoder-decoder cross attention is not
applied, as in the CIF model
key_bias: whether to use bias in attention.linear_k; False for Whisper models.
gradient_checkpointing: if True, rerun a forward-pass segment for
each checkpointed segment during backward.
tie_word_embedding: Tie or clone module weights depending on whether we are
using TorchScript or not
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
):
super().__init__()
attention_dim = encoder_output_size
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
self.embed = torch.nn.Sequential(
torch.nn.Identity() if input_layer == "no_pos" else
torch.nn.Embedding(vocab_size, attention_dim),
COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
positional_dropout_rate),
)
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.use_output_layer = use_output_layer
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = torch.nn.Identity()
self.num_blocks = num_blocks
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
COSYVOICE_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, key_bias),
COSYVOICE_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
key_bias) if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate, activation),
dropout_rate,
normalize_before,
) for _ in range(self.num_blocks)
])
self.gradient_checkpointing = gradient_checkpointing
self.tie_word_embedding = tie_word_embedding
def forward(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
r_ys_in_pad: torch.Tensor = torch.empty(0),
reverse_weight: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: not used in transformer decoder, in order to unify api
with bidirectional decoder
reverse_weight: not used in transformer decoder, in order to unify
api with bidirectional decoder
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out,
vocab_size) if use_output_layer is True,
torch.tensor(0.0), in order to unify api with bidirectional decoder
olens: (batch, )
NOTE(xcsong):
We pass the `__call__` method of the modules instead of `forward` to the
checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
"""
tgt = ys_in_pad
maxlen = tgt.size(1)
# tgt_mask: (B, 1, L)
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
tgt_mask = tgt_mask.to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1),
device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
x, _ = self.embed(tgt)
if self.gradient_checkpointing and self.training:
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
memory_mask)
else:
x = self.forward_layers(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, torch.tensor(0.0), olens
def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor) -> torch.Tensor:
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
return x
@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, x: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor) -> torch.Tensor:
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
layer.__call__, x, tgt_mask, memory, memory_mask)
return x
def forward_one_step(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
This is only used for decoding.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 before PyTorch 1.2
dtype=torch.bool in PyTorch 1.2+ (1.2 included)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
"""
x, _ = self.embed(tgt)
new_cache = []
for i, decoder in enumerate(self.decoders):
if cache is None:
c = None
else:
c = cache[i]
x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.use_output_layer:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending on whether we are using TorchScript or not"""
if not self.use_output_layer:
return
if jit_mode:
logging.info("clone emb.weight to output.weight")
self.output_layer.weight = torch.nn.Parameter(
self.embed[0].weight.clone())
else:
logging.info("tie emb.weight with output.weight")
self.output_layer.weight = self.embed[0].weight
if getattr(self.output_layer, "bias", None) is not None:
self.output_layer.bias.data = torch.nn.functional.pad(
self.output_layer.bias.data,
(
0,
self.output_layer.weight.shape[0] -
self.output_layer.bias.shape[0],
),
"constant",
0,
)
class BiTransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the hidden units number of position-wise feedforward
num_blocks: the number of decoder blocks
r_num_blocks: the number of right to left decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
key_bias: whether to use bias in attention.linear_k; False for Whisper models.
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
r_num_blocks: int = 0,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
key_bias: bool = True,
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
):
super().__init__()
self.tie_word_embedding = tie_word_embedding
self.left_decoder = TransformerDecoder(
vocab_size,
encoder_output_size,
attention_heads,
linear_units,
num_blocks,
dropout_rate,
positional_dropout_rate,
self_attention_dropout_rate,
src_attention_dropout_rate,
input_layer,
use_output_layer,
normalize_before,
key_bias=key_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding)
self.right_decoder = TransformerDecoder(
vocab_size,
encoder_output_size,
attention_heads,
linear_units,
r_num_blocks,
dropout_rate,
positional_dropout_rate,
self_attention_dropout_rate,
src_attention_dropout_rate,
input_layer,
use_output_layer,
normalize_before,
key_bias=key_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding)
def forward(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
r_ys_in_pad: torch.Tensor,
reverse_weight: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
ys_in_lens: input lengths of this batch (batch)
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
used for right to left decoder
reverse_weight: used for right to left decoder
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out,
vocab_size) if use_output_layer is True,
r_x: decoded token score (right to left decoder)
before softmax (batch, maxlen_out, vocab_size)
if use_output_layer is True,
olens: (batch, )
"""
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = torch.tensor(0.0)
if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask,
r_ys_in_pad, ys_in_lens)
return l_x, r_x, olens
def forward_one_step(
self,
memory: torch.Tensor,
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
This is only used for decoding.
Args:
memory: encoded memory, float32 (batch, maxlen_in, feat)
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 before PyTorch 1.2
dtype=torch.bool in PyTorch 1.2+ (1.2 included)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
"""
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
tgt_mask, cache)
def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending on whether we are using TorchScript or not"""
self.left_decoder.tie_or_clone_weights(jit_mode)
self.right_decoder.tie_or_clone_weights(jit_mode)
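# Usage sketch (illustrative, not part of the upstream file): one
# teacher-forcing pass through TransformerDecoder; the vocab size, feature
# dim and lengths are arbitrary, shapes follow the forward() docstring.
if __name__ == "__main__":
    decoder = TransformerDecoder(vocab_size=1000, encoder_output_size=256,
                                 num_blocks=2)
    memory = torch.randn(2, 50, 256)              # (batch, maxlen_in, feat)
    memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)
    ys_in_pad = torch.randint(0, 1000, (2, 10))   # (batch, maxlen_out)
    ys_in_lens = torch.tensor([10, 7])
    logits, _, _ = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
    print(logits.shape)  # torch.Size([2, 10, 1000])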

View File

@@ -0,0 +1,132 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Inter-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
If `None` is passed, Inter-attention is not used, as in
CIF, GPT, and other decoder-only models.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
def __init__(
self,
size: int,
self_attn: nn.Module,
src_attn: Optional[nn.Module],
feed_forward: nn.Module,
dropout_rate: float,
normalize_before: bool = True,
):
"""Construct an DecoderLayer object."""
super().__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
self.norm3 = nn.LayerNorm(size, eps=1e-5)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
def forward(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor
(#batch, maxlen_out).
memory (torch.Tensor): Encoded memory
(#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask
(#batch, maxlen_in).
cache (torch.Tensor): cached tensors.
(#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = tgt_mask[:, -1:, :]
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
if not self.normalize_before:
x = self.norm1(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0])
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
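# Usage sketch (illustrative, not part of the upstream file): the cache
# contract of a single DecoderLayer, wired with the attention/FFN modules
# used elsewhere in this commit. With all dropout at 0, the cached
# one-step pass reproduces the last frame of the full pass.
if __name__ == "__main__":
    from cosyvoice.transformer.attention import MultiHeadedAttention
    from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
    from cosyvoice.utils.mask import subsequent_mask
    size = 256
    layer = DecoderLayer(size,
                         MultiHeadedAttention(4, size, 0.0),
                         MultiHeadedAttention(4, size, 0.0),
                         PositionwiseFeedForward(size, 1024, 0.0),
                         dropout_rate=0.0)
    tgt = torch.randn(1, 5, size)
    tgt_mask = subsequent_mask(5).unsqueeze(0)    # (1, 5, 5) causal mask
    memory = torch.randn(1, 20, size)
    memory_mask = torch.ones(1, 1, 20, dtype=torch.bool)
    full, *_ = layer(tgt, tgt_mask, memory, memory_mask)
    # cache = this layer's output for the first maxlen_out - 1 frames
    step, *_ = layer(tgt, tgt_mask, memory, memory_mask,
                     cache=full[:, :-1, :])
    assert torch.allclose(step[:, -1], full[:, -1], atol=1e-5)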

View File

@@ -0,0 +1,293 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""
import math
from typing import Tuple, Union
import torch
import torch.nn.functional as F
import numpy as np
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int = 5000,
reverse: bool = False):
"""Construct an PositionalEncoding object."""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.max_len = max_len
self.pe = torch.zeros(self.max_len, self.d_model)
position = torch.arange(0, self.max_len,
dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.d_model))
self.pe[:, 0::2] = torch.sin(position * div_term)
self.pe[:, 1::2] = torch.cos(position * div_term)
self.pe = self.pe.unsqueeze(0)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int, torch.tensor): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
self.pe = self.pe.to(x.device)
pos_emb = self.position_encoding(offset, x.size(1), False)
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self,
offset: Union[int, torch.Tensor],
size: int,
apply_dropout: bool = True) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole-utterance level in the
non-streaming case, but this function will be called several times
with increasing input size in a streaming scenario, so dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
else: # for batched streaming decoding on GPU
assert torch.max(offset) + size <= self.max_len
index = offset.unsqueeze(1) + \
torch.arange(0, size).to(offset.device) # B X T
flag = index > 0
# remove negative offset
index = index * flag
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
if apply_dropout:
pos_emb = self.dropout(pos_emb)
return pos_emb
class RelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
"""Initialize class."""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.position_encoding(offset, x.size(1), False)
return self.dropout(x), self.dropout(pos_emb)
class WhisperPositionalEncoding(PositionalEncoding):
""" Sinusoids position encoding used in openai-whisper.encoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
super().__init__(d_model, dropout_rate, max_len)
self.xscale = 1.0
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment *
torch.arange(d_model // 2))
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
inv_timescales[np.newaxis, :]
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
delattr(self, "pe")
self.register_buffer("pe", pe.unsqueeze(0))
class LearnablePositionalEncoding(PositionalEncoding):
""" Learnable position encoding used in openai-whisper.decoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
super().__init__(d_model, dropout_rate, max_len)
# NOTE(xcsong): overwrite self.pe & self.xscale
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
self.xscale = 1.0
class NoPositionalEncoding(torch.nn.Module):
""" No position encoding
"""
def __init__(self, d_model: int, dropout_rate: float):
super().__init__()
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
def forward(self,
x: torch.Tensor,
offset: Union[int, torch.Tensor] = 0) \
-> Tuple[torch.Tensor, torch.Tensor]:
""" Just return zero vector for interface compatibility
"""
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
return self.dropout(x), pos_emb
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return torch.zeros(1, size, self.d_model)
class EspnetRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(EspnetRelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` is the position of the query vector and `j` the position
# of the key vector. We use positive relative positions when keys are
# to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
return self.dropout(x), self.dropout(pos_emb)
def position_encoding(self,
offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole-utterance level in the
non-streaming case, but this function will be called several times
with increasing input size in a streaming scenario, so dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
]
return pos_emb
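# Usage sketch (illustrative, not part of the upstream file): the absolute
# encoding supports a streaming offset, while the ESPnet-style relative
# variant returns a (2 * time - 1)-frame pos_emb for the shifting trick.
if __name__ == "__main__":
    abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
    x = torch.randn(1, 10, 256)
    y_full, _ = abs_pe(x)
    y_tail, _ = abs_pe(x[:, 5:], offset=5)  # same frames, streamed
    assert torch.allclose(y_full[:, 5:], y_tail)
    rel_pe = EspnetRelPositionalEncoding(256, dropout_rate=0.0)
    _, pos_emb = rel_pe(x)
    assert pos_emb.shape == (1, 2 * x.size(1) - 1, 256)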

View File

@@ -0,0 +1,472 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple
import torch
import torch.utils.checkpoint as ckpt
from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
COSYVOICE_EMB_CLASSES,
COSYVOICE_SUBSAMPLE_CLASSES,
COSYVOICE_ATTENTION_CLASSES,
COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask
class BaseEncoder(torch.nn.Module):
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
gradient_checkpointing: bool = False,
):
"""
Args:
input_size (int): input dim
output_size (int): dimension of attention
attention_heads (int): the number of heads of multi head attention
linear_units (int): the hidden units number of position-wise feed
forward
num_blocks (int): the number of decoder blocks
dropout_rate (float): dropout rate
attention_dropout_rate (float): dropout rate in attention
positional_dropout_rate (float): dropout rate after adding
positional encoding
input_layer (str): input layer type.
optional [linear, conv2d, conv2d6, conv2d8]
pos_enc_layer_type (str): Encoder positional encoding layer type.
optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
normalize_before (bool):
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
static_chunk_size (int): chunk size for static chunk training and
decoding
use_dynamic_chunk (bool): whether to use dynamic chunk size for
training or not. You can only use a fixed chunk (chunk_size > 0)
or a dynamic chunk size (use_dynamic_chunk = True)
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
dynamic chunk training
key_bias: whether to use bias in attention.linear_k; False for Whisper models.
gradient_checkpointing: if True, rerun a forward-pass segment for
each checkpointed segment during backward.
"""
super().__init__()
self._output_size = output_size
self.global_cmvn = global_cmvn
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
positional_dropout_rate),
)
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
self.gradient_checkpointing = gradient_checkpointing
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
NOTE(xcsong):
We pass the `__call__` method of the modules instead of `forward` to the
checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
if self.gradient_checkpointing and self.training:
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
mask_pad)
else:
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs
@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.encoders:
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
chunk_masks, pos_emb,
mask_pad)
return xs
def forward_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
computation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
chunk_size = xs.size(1)
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
xs,
att_mask,
pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
if self.normalize_before:
xs = self.after_norm(xs)
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = torch.cat(r_att_cache, dim=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
return (xs, r_att_cache, r_cnn_cache)
def forward_chunk_by_chunk(
self,
xs: torch.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (torch.Tensor): (1, max_len, dim)
decoding_chunk_size (int): decoding chunk size
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, att_cache,
cnn_cache) = self.forward_chunk(chunk_xs, offset,
required_cache_size, att_cache,
cnn_cache)
outputs.append(y)
offset += y.size(1)
ys = torch.cat(outputs, 1)
masks = torch.ones((1, 1, ys.size(1)),
device=ys.device,
dtype=torch.bool)
return ys, masks
class TransformerEncoder(BaseEncoder):
"""Transformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "abs_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
key_bias: bool = True,
selfattention_layer_type: str = "selfattn",
activation_type: str = "relu",
gradient_checkpointing: bool = False,
):
""" Construct TransformerEncoder
See Encoder for the meaning of each parameter.
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing)
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
output_size,
attention_dropout_rate,
key_bias),
PositionwiseFeedForward(output_size, linear_units,
dropout_rate, activation),
dropout_rate, normalize_before) for _ in range(num_blocks)
])
class ConformerEncoder(BaseEncoder):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "rel_pos",
normalize_before: bool = True,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
positionwise_conv_kernel_size: int = 1,
macaron_style: bool = True,
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
key_bias: bool = True,
gradient_checkpointing: bool = False,
):
"""Construct ConformerEncoder
Args:
input_size to use_dynamic_chunk: see BaseEncoder
positionwise_conv_kernel_size (int): Kernel size of positionwise
conv1d layer.
macaron_style (bool): Whether to use macaron style for
positionwise layer.
selfattention_layer_type (str): Encoder attention layer type,
the parameter has no effect now, it's just for configuration
compatibility.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
key_bias: whether to use bias in attention.linear_k; False for Whisper models.
"""
super().__init__(input_size, output_size, attention_heads,
linear_units, num_blocks, dropout_rate,
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing)
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
# self-attention module definition
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
key_bias,
)
# feed-forward module definition
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args),
PositionwiseFeedForward(*positionwise_layer_args),
PositionwiseFeedForward(
*positionwise_layer_args) if macaron_style else None,
ConvolutionModule(
*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
) for _ in range(num_blocks)
])
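# Usage sketch (illustrative, not part of the upstream file): a
# chunk-trained causal ConformerEncoder yields the same output length from
# the full-utterance forward() and the streaming forward_chunk_by_chunk()
# API (80-dim features, 4x conv2d subsampling assumed from the defaults).
if __name__ == "__main__":
    encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=2,
                               static_chunk_size=16, causal=True)
    encoder.eval()
    xs = torch.randn(1, 203, 80)   # (B, T, mel-dim)
    xs_lens = torch.tensor([203])
    with torch.no_grad():
        full, _ = encoder(xs, xs_lens)
        stream, _ = encoder.forward_chunk_by_chunk(xs,
                                                   decoding_chunk_size=16)
    print(full.shape, stream.shape)  # both torch.Size([1, 50, 256])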

View File

@@ -0,0 +1,236 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class TransformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: torch.nn.Module,
dropout_rate: float,
normalize_before: bool = True,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = nn.LayerNorm(size, eps=1e-5)
self.norm2 = nn.LayerNorm(size, eps=1e-5)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): just for interface compatibility
to ConformerEncoderLayer
mask_pad (torch.Tensor): not used in the transformer layer,
just for a unified api with the conformer.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2), not used here, it's for interface
compatibility to ConformerEncoderLayer.
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
return x, mask, new_att_cache, fake_cnn_cache
class ConformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvolutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: Optional[nn.Module] = None,
feed_forward_macaron: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FFN module
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
self.norm_final = nn.LayerNorm(
size, eps=1e-5) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1, time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
att_cache)
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
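# Usage sketch (illustrative, not part of the upstream file): KV-cache
# growth across two chunks for a single TransformerEncoderLayer, assuming
# the (#batch, head, cache_t, d_k * 2) att_cache layout documented above.
if __name__ == "__main__":
    from cosyvoice.transformer.attention import MultiHeadedAttention
    from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
    size, heads = 256, 4
    layer = TransformerEncoderLayer(size,
                                    MultiHeadedAttention(heads, size, 0.0),
                                    PositionwiseFeedForward(size, 1024, 0.0),
                                    dropout_rate=0.0)
    layer.eval()
    att_cache = torch.zeros((1, heads, 0, size // heads * 2))
    for t in (8, 8):
        x = torch.randn(1, t, size)
        # each frame may attend to all cached frames plus its own chunk
        mask = torch.ones(1, t, att_cache.size(2) + t, dtype=torch.bool)
        x, _, att_cache, _ = layer(x, mask, torch.empty(0),
                                   att_cache=att_cache)
    print(att_cache.shape)  # torch.Size([1, 4, 16, 128])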

View File

@@ -0,0 +1,96 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""
import torch
from torch import nn
class LabelSmoothingLoss(nn.Module):
"""Label-smoothing loss.
In a standard CE loss, the label's data distribution is:
[0,1,2] ->
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
In the smoothed version of CE loss, some probability mass
is taken from the true label prob (1.0) and divided
among the other labels.
e.g.
smoothing=0.1
[0,1,2] ->
[
[0.9, 0.05, 0.05],
[0.05, 0.9, 0.05],
[0.05, 0.05, 0.9],
]
Args:
size (int): the number of classes
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
normalize loss by sequence length if True
normalize loss by batch size if False
"""
def __init__(self,
size: int,
padding_idx: int,
smoothing: float,
normalize_length: bool = False):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = nn.KLDivLoss(reduction="none")
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.normalize_length = normalize_length
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute loss between x and target.
The model output and data label tensors are flattened to
(batch*seqlen, class) shape and a mask is applied to the
padding part, which should not be included in the loss.
Args:
x (torch.Tensor): prediction (batch, seqlen, class)
target (torch.Tensor):
target signal masked with self.padding_idx (batch, seqlen)
Returns:
loss (torch.Tensor) : The KL loss, scalar float value
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = torch.zeros_like(x)
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
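# Usage sketch (illustrative, not part of the upstream file): a 5-class
# loss where padding positions (marked -1) are excluded and the loss is
# normalized by batch size (normalize_length defaults to False).
if __name__ == "__main__":
    criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
    logits = torch.randn(2, 4, 5, requires_grad=True)  # (batch, seqlen, class)
    target = torch.tensor([[1, 2, 3, -1],
                           [0, 4, -1, -1]])            # -1 marks padding
    loss = criterion(logits, target)
    loss.backward()
    print(loss.item())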

View File

@@ -0,0 +1,115 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
The feed-forward layer is applied to each position of the sequence.
The output dim is the same as the input dim.
Args:
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = torch.nn.Linear(hidden_units, idim)
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
class MoEFFNLayer(torch.nn.Module):
"""
Mixture of experts with a positionwise feed forward layer
See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
The output dim is the same as the input dim.
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
Args:
n_expert: number of experts
n_expert_per_token: The actual number of experts used for each frame
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(
self,
n_expert: int,
n_expert_per_token: int,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
):
super(MoEFFNLayer, self).__init__()
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
self.experts = torch.nn.ModuleList(
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
activation) for _ in range(n_expert))
self.n_expert_per_token = n_expert_per_token
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Foward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
B, L, D = xs.size(
) # batch size, sequence length, embedding dimension (idim)
xs = xs.view(-1, D) # (B*L, D)
router = self.gate(xs) # (B*L, n_expert)
logits, indices = torch.topk(
router, self.n_expert_per_token
) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
weights = torch.nn.functional.softmax(
logits, dim=1,
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
output = torch.zeros_like(xs) # (B*L, D)
for i, expert in enumerate(self.experts):
mask = indices == i
batch_idx, ith_expert = torch.where(mask)
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
xs[batch_idx])
return output.view(B, L, D)
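# Usage sketch (illustrative, not part of the upstream file): each frame is
# routed to its top-2 of 4 experts; the MoE layer preserves the (B, L, D)
# shape of the plain PositionwiseFeedForward it replaces.
if __name__ == "__main__":
    moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=64,
                      hidden_units=256, dropout_rate=0.0)
    xs = torch.randn(3, 7, 64)
    print(moe(xs).shape)  # torch.Size([3, 7, 64])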

View File

@@ -0,0 +1,383 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
class BaseSubsampling(torch.nn.Module):
def __init__(self):
super().__init__()
self.right_context = 0
self.subsampling_rate = 1
def position_encoding(self, offset: Union[int, torch.Tensor],
size: int) -> torch.Tensor:
return self.pos_enc.position_encoding(offset, size)
class EmbedinigNoSubsampling(BaseSubsampling):
"""Embedding input without subsampling
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
super().__init__()
self.embed = torch.nn.Embedding(idim, odim)
self.pos_enc = pos_enc_class
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.embed(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an linear object."""
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Linear(idim, odim),
torch.nn.LayerNorm(odim, eps=1e-5),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv1dSubsampling2(BaseSubsampling):
"""Convolutional 1D subsampling (to 1/2 length).
It is designed for Whisper, ref:
https://github.com/openai/whisper/blob/main/whisper/model.py
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv1dSubsampling2 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
torch.nn.GELU(),
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
torch.nn.GELU(),
)
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 2
# 4 = (3 - 1) * 1 + (3 - 1) * 1
self.right_context = 4
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
torch.Tensor: positional encoding
"""
time = x.size(1)
x = x.transpose(1, 2) # (b, f, t)
x = self.conv(x)
x = x.transpose(1, 2) # (b, t, f)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
class Conv2dSubsampling4(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 4
# 6 = (3 - 1) * 1 + (3 - 1) * 2
self.right_context = 6
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 4.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 4.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
class Conv2dSubsampling6(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/6 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling6 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
odim)
self.pos_enc = pos_enc_class
# 10 = (3 - 1) * 1 + (5 - 1) * 2
self.subsampling_rate = 6
self.right_context = 10
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 6.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 6.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
class Conv2dSubsampling8(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/8 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling8 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
self.pos_enc = pos_enc_class
self.subsampling_rate = 8
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
self.right_context = 14
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 8.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 8.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
class LegacyLinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an linear object."""
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Linear(idim, odim),
torch.nn.LayerNorm(odim, eps=1e-5),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
)
self.pos_enc = pos_enc_class
self.right_context = 0
self.subsampling_rate = 1
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.out(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
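
# A minimal usage sketch of Conv2dSubsampling4, with toy shapes. _DummyPosEnc is a
# hypothetical stand-in for the positional-encoding modules in
# cosyvoice.transformer.embedding: it returns the input unchanged plus a zero embedding.
if __name__ == "__main__":
    class _DummyPosEnc(torch.nn.Module):
        def forward(self, x, offset=0):
            return x, torch.zeros(1, x.size(1), x.size(2))

    sub4 = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.0,
                              pos_enc_class=_DummyPosEnc())
    x = torch.randn(2, 100, 80)                      # (batch, time, idim)
    x_mask = torch.ones(2, 1, 100, dtype=torch.bool)
    y, pos_emb, y_mask = sub4(x, x_mask)
    print(y.shape, y_mask.shape)                     # about time // 4 frames remain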

View File

View File

@@ -0,0 +1,70 @@
# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from cosyvoice.transformer.activation import Swish
from cosyvoice.transformer.subsampling import (
LinearNoSubsampling,
EmbedinigNoSubsampling,
Conv1dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
)
from cosyvoice.transformer.embedding import (PositionalEncoding,
RelPositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from cosyvoice.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
COSYVOICE_ACTIVATION_CLASSES = {
"hardtanh": torch.nn.Hardtanh,
"tanh": torch.nn.Tanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": getattr(torch.nn, "SiLU", Swish),
"gelu": torch.nn.GELU,
}
COSYVOICE_SUBSAMPLE_CLASSES = {
"linear": LinearNoSubsampling,
"linear_legacy": LegacyLinearNoSubsampling,
"embed": EmbedinigNoSubsampling,
"conv1d2": Conv1dSubsampling2,
"conv2d": Conv2dSubsampling4,
"conv2d6": Conv2dSubsampling6,
"conv2d8": Conv2dSubsampling8,
'paraformer_dummy': torch.nn.Identity
}
COSYVOICE_EMB_CLASSES = {
"embed": PositionalEncoding,
"abs_pos": PositionalEncoding,
"rel_pos": RelPositionalEncoding,
"rel_pos_espnet": EspnetRelPositionalEncoding,
"no_pos": NoPositionalEncoding,
"abs_pos_whisper": WhisperPositionalEncoding,
"embed_learnable_pe": LearnablePositionalEncoding,
}
COSYVOICE_ATTENTION_CLASSES = {
"selfattn": MultiHeadedAttention,
"rel_selfattn": RelPositionMultiHeadedAttention,
}

93
cosyvoice/utils/common.py Normal file
View File

@@ -0,0 +1,93 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
from typing import List
import torch
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
max_len = max([len(item) for item in xs])
batchs = len(xs)
ndim = xs[0].ndim
if ndim == 1:
pad_res = torch.zeros(batchs,
max_len,
dtype=xs[0].dtype,
device=xs[0].device)
elif ndim == 2:
pad_res = torch.zeros(batchs,
max_len,
xs[0].shape[1],
dtype=xs[0].dtype,
device=xs[0].device)
elif ndim == 3:
pad_res = torch.zeros(batchs,
max_len,
xs[0].shape[1],
xs[0].shape[2],
dtype=xs[0].dtype,
device=xs[0].device)
else:
raise ValueError(f"Unsupported ndim: {ndim}")
pad_res.fill_(pad_value)
for i in range(batchs):
pad_res[i, :len(xs[i])] = xs[i]
return pad_res
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
ignore_label: int) -> torch.Tensor:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
torch.Tensor: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return (numerator / denominator).detach()
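
# A toy sketch of th_accuracy: 2 utterances, Lmax = 3, 4 output classes. The padded
# target marked with IGNORE_ID is excluded from both numerator and denominator.
if __name__ == "__main__":
    pad_targets = torch.tensor([[1, 2, 3], [0, 1, IGNORE_ID]])  # (B, Lmax)
    pad_outputs = torch.randn(2 * 3, 4)                         # (B * Lmax, D)
    acc = th_accuracy(pad_outputs, pad_targets, ignore_label=IGNORE_ID)
    print(acc)  # scalar tensor in [0.0, 1.0], computed over the 5 non-padded labels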

110
cosyvoice/utils/executor.py Normal file
View File

@@ -0,0 +1,110 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
import os
import torch
import torch.distributed as dist
from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
class Executor:
def __init__(self):
self.step = 0
self.epoch = 0
self.rank = int(os.environ.get('RANK', 0))
self.device = torch.device('cuda:{}'.format(self.rank))
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
''' Train one epoch
'''
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))
# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
model.train()
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
with model_context():
for batch_idx, batch_dict in enumerate(train_data_loader):
info_dict["tag"] = "TRAIN"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
if cosyvoice_join(group_join, info_dict):
break
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
info_dict = batch_forward(model, batch_dict, info_dict)
info_dict = batch_backward(model, info_dict)
info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
log_per_step(writer, info_dict)
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
model.train()
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
self.step += 1
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
@torch.inference_mode()
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
        ''' Cross validation
'''
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
model.eval()
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
for batch_idx, batch_dict in enumerate(cv_data_loader):
info_dict["tag"] = "CV"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
num_utts = len(batch_dict["utts"])
total_num_utts += num_utts
info_dict = batch_forward(model, batch_dict, info_dict)
for k, v in info_dict['loss_dict'].items():
if k not in total_loss_dict:
total_loss_dict[k] = []
total_loss_dict[k].append(v.item() * num_utts)
log_per_step(None, info_dict)
for k, v in total_loss_dict.items():
total_loss_dict[k] = sum(v) / total_num_utts
info_dict['loss_dict'] = total_loss_dict
log_per_save(writer, info_dict)
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
save_model(model, model_name, info_dict)
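
# A toy sketch of the loss aggregation used in cv() above: per-batch losses are
# weighted by utterance count, then normalised by the total number of utterances.
# The numbers are arbitrary examples.
if __name__ == "__main__":
    batch_losses, batch_num_utts = [0.9, 1.1, 0.7], [8, 8, 4]
    cv_loss = sum(l * n for l, n in zip(batch_losses, batch_num_utts)) / sum(batch_num_utts)
    print(cv_loss)  # 0.94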

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torchaudio
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
return lists
def read_json_lists(list_file):
lists = read_lists(list_file)
results = {}
for fn in lists:
with open(fn, 'r', encoding='utf8') as fin:
results.update(json.load(fin))
return results
def load_wav(wav, target_sr):
speech, sample_rate = torchaudio.load(wav)
speech = speech.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
return speech
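
# A minimal usage sketch, assuming a local file 'prompt.wav' (a hypothetical name)
# recorded at a sample rate of at least 16 kHz: load_wav downmixes to mono and
# resamples to target_sr.
if __name__ == "__main__":
    speech = load_wav('prompt.wav', target_sr=16000)
    print(speech.shape)  # (1, num_samples) at 16 kHz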

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
# whether contain chinese character
def contains_chinese(text):
return bool(chinese_char_pattern.search(text))
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
return text
# remove meaningless symbol
def remove_bracket(text):
    text = text.replace('（', '').replace('）', '')
    text = text.replace('【', '').replace('】', '')
text = text.replace('`', '').replace('`', '')
text = text.replace("——", " ")
return text
# spell Arabic numerals
def spell_out_number(text: str, inflect_parser):
new_text = []
st = None
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
new_text.append(num_str)
st = None
new_text.append(c)
else:
if st is None:
st = i
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
# split paragraph logic:
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
# 2. calculate sentence length according to lang
# 3. split sentences according to punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
def calc_utt_length(_text: str):
if lang == "zh":
return len(_text)
else:
return len(tokenize(_text))
def should_merge(_text: str):
if lang == "zh":
return len(_text) < merge_len
else:
return len(tokenize(_text)) < merge_len
if lang == "zh":
        pounc = ['。', '？', '！', '；', '：', '.', '?', '!', ';']
else:
pounc = ['.', '?', '!', ';', ':']
if comma_split:
        pounc.extend(['，', ','])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
else:
st = i + 1
final_utts = []
cur_utt = ""
for utt in utts:
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
final_utts.append(cur_utt)
cur_utt = ""
cur_utt = cur_utt + utt
if len(cur_utt) > 0:
if should_merge(cur_utt) and len(final_utts) != 0:
final_utts[-1] = final_utts[-1] + cur_utt
else:
final_utts.append(cur_utt)
return final_utts
# remove blanks between Chinese characters
def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)

227
cosyvoice/utils/mask.py Normal file
View File

@@ -0,0 +1,227 @@
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=torch.bool)
return torch.tril(ret)
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
    This mask is used only in the decoder, which works in an auto-regressive mode.
    This means the current step may only attend to its left (previous) steps.
    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
arange = torch.arange(size, device=device)
mask = arange.expand(size, size)
arange = arange.unsqueeze(-1)
mask = mask <= arange
return mask
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
enable_full_context: bool = True):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
enable_full_context (bool):
True: chunk size is either [1, 25] or full context(max_len)
False: chunk size ~ U[1, 25]
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2 and enable_full_context:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
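
# A quick sketch reproducing the docstring examples above: a padding mask for
# lengths [5, 3, 2] and a streaming chunk mask of size 4 with chunk_size 2.
if __name__ == "__main__":
    print(make_pad_mask(torch.tensor([5, 3, 2])).int())
    # tensor([[0, 0, 0, 0, 0],
    #         [0, 0, 0, 1, 1],
    #         [0, 0, 1, 1, 1]], dtype=torch.int32)
    print(subsequent_chunk_mask(4, 2).int())
    # tensor([[1, 1, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 1],
    #         [1, 1, 1, 1]], dtype=torch.int32)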

View File

@@ -0,0 +1,717 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Ximalaya Inc (Yuguang Yang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
# NeMo(https://github.com/NVIDIA/NeMo)
from typing import Union
import math
import warnings
import torch
from torch.optim.lr_scheduler import _LRScheduler
class WarmupLR(_LRScheduler):
"""The WarmupLR scheduler
    This scheduler is almost the same as the NoamLR Scheduler except for the
    following difference:
NoamLR:
lr = optimizer.lr * model_size ** -0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
WarmupLR:
lr = optimizer.lr * warmup_step ** 0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
    Note that the maximum lr equals optimizer.lr in this scheduler.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer, last_epoch)
def __repr__(self):
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
def get_lr(self):
step_num = self.last_epoch + 1
if self.warmup_steps == 0:
return [lr * step_num**-0.5 for lr in self.base_lrs]
else:
return [
lr * self.warmup_steps**0.5 *
min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
for lr in self.base_lrs
]
def set_step(self, step: int):
self.last_epoch = step
class WarmupPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
assert not (warmup_steps is not None and warmup_ratio is not None),\
"Either use particular number of step or ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
if step <= self.warmup_steps and self.warmup_steps > 0:
return self._get_warmup_lr(step)
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
class SquareRootConstantPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(self,
optimizer,
*,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
assert not (constant_steps is not None
and constant_ratio is not None), \
"Either use particular number of step or ratio"
assert constant_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
self.constant_lr = 1 / (constant_steps**0.5)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
if step <= self.constant_steps:
return [self.constant_lr for _ in self.base_lrs]
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
class WarmupHoldPolicy(WarmupPolicy):
"""Variant of WarmupPolicy which maintains high
learning rate for a defined number of steps.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
hold_steps: Number of training steps to
hold the learning rate after warm up
hold_ratio: Ratio of hold steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
hold_steps=None,
hold_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (hold_steps is not None and hold_ratio is not None), \
"Either use particular number of step or ratio"
assert hold_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
self.min_lr = min_lr
self._last_warmup_lr = 0.0
# Necessary to duplicate as class attributes are hidden in inner class
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if hold_steps is not None:
self.hold_steps = hold_steps + self.warmup_steps
elif hold_ratio is not None:
self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
else:
self.hold_steps = 0
super().__init__(
optimizer,
warmup_steps=warmup_steps,
warmup_ratio=warmup_ratio,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler,"
" "
"please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
# Warmup phase
if step <= self.warmup_steps and self.warmup_steps > 0:
return self._get_warmup_lr(step)
# Hold phase
if (step >= self.warmup_steps) and (step < self.hold_steps):
return self.base_lrs
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
class WarmupAnnealHoldPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
min_lr: Minimum lr to hold the learning rate after decay at.
constant_steps: Number of steps to keep lr constant at.
constant_ratio: Ratio of steps to keep lr constant.
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (warmup_steps is not None
and warmup_ratio is not None), \
"Either use particular number of step or ratio"
assert not (constant_steps is not None
and constant_ratio is not None), \
"Either use constant_steps or constant_ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
self.decay_steps = max_steps - (self.constant_steps +
self.warmup_steps)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
# Warmup steps
if self.warmup_steps > 0 and step <= self.warmup_steps:
return self._get_warmup_lr(step)
# Constant steps after warmup and decay
if self.constant_steps > 0 and (
self.warmup_steps + self.decay_steps) < step <= self.max_steps:
return self._get_constant_lr(step)
# Min lr after max steps of updates
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_constant_lr(self, step):
return [self.min_lr for _ in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
mult = ((max_steps - step) / max_steps)**0.5
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _square_annealing(initial_lr, step, max_steps, min_lr):
mult = ((max_steps - step) / max_steps)**2
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _cosine_annealing(initial_lr, step, max_steps, min_lr):
mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
out_lr = (initial_lr - min_lr) * mult + min_lr
return out_lr
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
decay_steps, min_lr):
assert max_lr > min_lr
# Use linear warmup for the initial part.
if warmup_steps > 0 and step <= warmup_steps:
return max_lr * float(step) / float(warmup_steps)
# For any steps larger than `decay_steps`, use `min_lr`.
if step > warmup_steps + decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = step - warmup_steps
decay_steps_ = decay_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
return min_lr + coeff * delta_lr
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
if cycle:
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
decay_steps *= multiplier
else:
step = min(step, decay_steps)
p = step / decay_steps
lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
lr += min_lr
return lr
def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
decay_rate, min_lr):
# hold_steps = total number of steps
# to hold the LR, not the warmup + hold steps.
T_warmup_decay = max(1, warmup_steps**decay_rate)
T_hold_decay = max(1, (step - hold_steps)**decay_rate)
lr = (initial_lr * T_warmup_decay) / T_hold_decay
lr = max(lr, min_lr)
return lr
class SquareAnnealing(WarmupPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=1e-5,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
new_lrs = [
_square_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
return new_lrs
class SquareRootAnnealing(WarmupPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=0,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
new_lrs = [
_squareroot_annealing(initial_lr=initial_lr,
step=step,
max_steps=self.max_steps,
min_lr=self.min_lr)
for initial_lr in self.base_lrs
]
return new_lrs
class CosineAnnealing(WarmupAnnealHoldPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=0,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate "
f"that was lower than the minimum learning rate.")
if self.constant_steps is None or self.constant_steps == 0:
new_lrs = [
_cosine_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
else:
new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
return new_lrs
def _get_warmup_lr(self, step):
if self.constant_steps is None or self.constant_steps == 0:
return super()._get_warmup_lr(step)
else:
# Use linear warmup for the initial part.
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_constant_lr(self, step):
# Only called when `constant_steps` > 0.
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_linear_warmup_with_cosine_annealing_lr(self, step):
# Cosine Schedule for Megatron LM,
# slightly different warmup schedule + constant LR at the end.
new_lrs = [
_linear_warmup_with_cosine_annealing(
max_lr=self.base_lrs[0],
warmup_steps=self.warmup_steps,
step=step,
decay_steps=self.decay_steps,
min_lr=self.min_lr,
) for _ in self.base_lrs
]
return new_lrs
class NoamAnnealing(_LRScheduler):
def __init__(self,
optimizer,
*,
d_model,
warmup_steps=None,
warmup_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
self._normalize = d_model**(-0.5)
assert not (warmup_steps is not None
and warmup_ratio is not None), \
"Either use particular number of step or ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = max(1, self.last_epoch)
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate "
f"that was lower than the minimum learning rate.")
new_lrs = [
self._noam_annealing(initial_lr=initial_lr, step=step)
for initial_lr in self.base_lrs
]
return new_lrs
def _noam_annealing(self, initial_lr, step):
if self.warmup_steps > 0:
mult = self._normalize * min(step**(-0.5),
step * (self.warmup_steps**(-1.5)))
else:
mult = self._normalize * step**(-0.5)
out_lr = initial_lr * mult
if step > self.warmup_steps:
out_lr = max(out_lr, self.min_lr)
return out_lr
class NoamHoldAnnealing(WarmupHoldPolicy):
def __init__(self,
optimizer,
*,
max_steps,
decay_rate=0.5,
min_lr=0.0,
last_epoch=-1,
**kwargs):
"""
From Nemo:
Implementation of the Noam Hold Annealing policy
from the SqueezeFormer paper.
Unlike NoamAnnealing, the peak learning rate
can be explicitly set for this scheduler.
The schedule first performs linear warmup,
then holds the peak LR, then decays with some schedule for
the remainder of the steps.
        Therefore the min-lr is still dependent
        on the hyperparameters selected.
        Its schedule is determined by three factors:
        Warmup Steps: Initial stage, where linear warmup
        occurs until the peak LR is reached. Unlike NoamAnnealing,
the peak LR is explicitly stated here instead of a scaling factor.
Hold Steps: Intermediate stage, where the peak LR
is maintained for some number of steps. In this region,
the high peak LR allows the model to converge faster
if training is stable. However the high LR
may also cause instability during training.
Should usually be a significant fraction of training
steps (around 30-40% of the entire training steps).
Decay Steps: Final stage, where the LR rapidly decays
with some scaling rate (set by decay rate).
To attain Noam decay, use 0.5,
for Squeezeformer recommended decay, use 1.0.
The fast decay after prolonged high LR during
hold phase allows for rapid convergence.
References:
- [Squeezeformer:
An Efficient Transformer for Automatic Speech Recognition]
(https://arxiv.org/abs/2206.00888)
Args:
optimizer: Pytorch compatible Optimizer object.
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
hold_steps: Number of training steps to
hold the learning rate after warm up
hold_ratio: Ratio of hold steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
decay_rate: Float value describing the polynomial decay
after the hold period. Default value
of 0.5 corresponds to Noam decay.
min_lr: Minimum learning rate.
"""
self.decay_rate = decay_rate
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
if self.warmup_steps is None or self.warmup_steps == 0:
raise ValueError(
"Noam scheduler cannot be used without warmup steps")
if self.hold_steps > 0:
hold_steps = self.hold_steps - self.warmup_steps
else:
hold_steps = 0
new_lrs = [
_noam_hold_annealing(
initial_lr,
step=step,
warmup_steps=self.warmup_steps,
hold_steps=hold_steps,
decay_rate=self.decay_rate,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
return new_lrs
def set_step(self, step: int):
self.last_epoch = step
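
# A toy run of WarmupLR on a throwaway linear model: during warmup the lr grows
# linearly towards the optimizer's configured lr (reached at step == warmup_steps),
# then decays as step ** -0.5. Sizes and step counts are arbitrary examples.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = WarmupLR(optimizer, warmup_steps=10)
    lrs = []
    for _ in range(30):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    print(round(max(lrs), 6))  # peaks at the configured 1e-3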

View File

@@ -0,0 +1,286 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Horizon Inc. (authors: Xingchen Song)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import logging
import os
import torch
import json
import re
import datetime
import yaml
import deepspeed
import torch.optim as optim
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
from cosyvoice.dataset.dataset import Dataset
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing
def init_distributed(args):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine == 'torch_ddp':
torch.cuda.set_device(local_rank)
dist.init_process_group(args.dist_backend)
else:
deepspeed.init_distributed(dist_backend=args.dist_backend)
return world_size, local_rank, rank
def init_dataset_and_dataloader(args, configs):
train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
    # do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
def check_modify_and_save_config(args, configs):
if args.train_engine == "torch_ddp":
configs['train_conf']["dtype"] = 'fp32'
else:
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
configs['train_conf']["dtype"] = "fp16"
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
configs['train_conf']["dtype"] = "bf16"
else:
configs['train_conf']["dtype"] = "fp32"
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
# if use deepspeed, override ddp config
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
return configs
def wrap_cuda_model(args, model):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if args.train_engine == "torch_ddp": # native pytorch ddp
assert (torch.cuda.is_available())
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
else:
if int(os.environ.get('RANK', 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size)
return model
def init_optimizer_and_scheduler(args, configs, model):
if configs['train_conf']['optim'] == 'adam':
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
elif configs['train_conf']['optim'] == 'adamw':
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
else:
raise ValueError("unknown scheduler: " + configs['train_conf'])
# use deepspeed optimizer for speedup
if args.train_engine == "deepspeed":
def scheduler(opt):
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=None,
lr_scheduler=scheduler,
model_parameters=model.parameters())
return model, optimizer, scheduler
def init_summarywriter(args):
writer = None
if int(os.environ.get('RANK', 0)) == 0:
os.makedirs(args.model_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
return writer
def save_model(model, model_name, info_dict):
rank = int(os.environ.get('RANK', 0))
model_dir = info_dict["model_dir"]
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
if info_dict["train_engine"] == "torch_ddp":
if rank == 0:
torch.save(model.module.state_dict(), save_model_path)
else:
with torch.no_grad():
model.save_checkpoint(save_dir=model_dir,
tag=model_name,
client_state=info_dict)
if rank == 0:
info_path = re.sub('.pt$', '.yaml', save_model_path)
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
def cosyvoice_join(group_join, info_dict):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
if info_dict["batch_idx"] != 0:
        # we try to join all ranks in both ddp and deepspeed modes, in case different ranks have different lr
try:
dist.monitored_barrier(group=group_join,
timeout=group_join.options._timeout)
return False
except RuntimeError as e:
logging.info("Detected uneven workload distribution: {}\n".format(e) +
"Break current worker to manually join all workers, " +
"world_size {}, current rank {}, current local_rank {}\n".
format(world_size, rank, local_rank))
return True
else:
return False
def batch_forward(model, batch, info_dict):
device = int(os.environ.get('LOCAL_RANK', 0))
dtype = info_dict["dtype"]
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
else: # fp32
dtype = torch.float32
if info_dict['train_engine'] == 'torch_ddp':
autocast = nullcontext()
else:
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
with autocast:
info_dict['loss_dict'] = model(batch, device)
return info_dict
def batch_backward(model, info_dict):
if info_dict["train_engine"] == "deepspeed":
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
else:
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
scaled_loss.backward()
info_dict['loss_dict']['loss'] = scaled_loss
return info_dict
def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
grad_norm = 0.0
if info_dict['train_engine'] == "deepspeed":
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
model.step()
grad_norm = model.get_global_grad_norm()
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
info_dict["lr"] = optimizer.param_groups[0]['lr']
info_dict["grad_norm"] = grad_norm
return info_dict
def log_per_step(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict.get('epoch', 0)
step = info_dict["step"]
batch_idx = info_dict["batch_idx"]
loss_dict = info_dict['loss_dict']
rank = int(os.environ.get('RANK', 0))
    # only rank 0 writes to tensorboard to avoid multi-process writes
if writer is not None:
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
for k in ['epoch', 'lr', 'grad_norm']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
# TRAIN & CV, Shell log (stdout)
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
for name, value in loss_dict.items():
log_str += '{} {:.6f} '.format(name, value)
if tag == "TRAIN":
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
info_dict["lr"], info_dict['grad_norm'])
log_str += ' rank {}'.format(rank)
logging.debug(log_str)
def log_per_save(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict["epoch"]
step = info_dict["step"]
loss_dict = info_dict["loss_dict"]
lr = info_dict['lr']
rank = int(os.environ.get('RANK', 0))
logging.info(
        'Epoch {} Step {} CV info lr {} {} rank {}'.format(
            epoch, step + 1, lr, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()]), rank))
if writer is not None:
for k in ['epoch', 'lr']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)

BIN
cross_lingual_prompt.wav Normal file

Binary file not shown.

View File

@@ -0,0 +1,197 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
spk_embed_dim: 192
# model params
# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from this single yaml.
# for system/third-party classes/functions, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866
speech_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 8
linear_units: 2048
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 8
linear_units: 2048
num_blocks: 7
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
dropout: 0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 12000
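# with batch_type 'dynamic', utterances are grouped by total frame count rather than a fixed
# batch size; lower max_frames_in_batch if you run out of GPU memory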
padding: !name:cosyvoice.dataset.processor.padding
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.002 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1

View File

@@ -0,0 +1,197 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
spk_embed_dim: 192
# model params
# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from this single yaml.
# for system/third-party classes/functions, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866
speech_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 16
linear_units: 4096
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 16
linear_units: 4096
num_blocks: 14
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
dropout: 0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 2500
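    # warmuplr ramps the learning rate up over warmup_steps and then decays it; fine-tuning from
    # a pretrained checkpoint uses a much shorter warmup than training from scratch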
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1

View File

@@ -0,0 +1,42 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 256,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.001,
"weight_decay": 0.0001,
"torch_adam": true,
"adam_w_mode": true
}
}
}

View File

@@ -0,0 +1 @@
../../../cosyvoice

View File

@@ -0,0 +1,97 @@
#!/bin/bash
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
remove_archive=true
shift
fi
if [ $# -ne 3 ]; then
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
echo " train-clean-100, train-clean-360, train-other-500."
exit 1
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
echo "$0: no such directory $data"
exit 1
fi
part_ok=false
list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
for x in $list; do
if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
exit 1
fi
if [ -z "$url" ]; then
echo "$0: empty URL base."
exit 1
fi
if [ -f $data/LibriSpeech/$part/.complete ]; then
echo "$0: data part $part was already successfully extracted, nothing to do."
exit 0
fi
# sizes of the archive files in bytes. These are from some older versions.
sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
# sizes_new is the archive file sizes of the final release. Some of these sizes are of
# things we probably won't download.
sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
if [ -f $data/$part.tar.gz ]; then
size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
size_ok=false
for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
if ! $size_ok; then
echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
echo "does not equal the size of one of the archives."
rm $data/$part.tar.gz
else
echo "$data/$part.tar.gz exists and appears to be complete."
fi
fi
if [ ! -f $data/$part.tar.gz ]; then
if ! which wget >/dev/null; then
echo "$0: wget is not installed."
exit 1
fi
full_url=$url/$part.tar.gz
echo "$0: downloading data from $full_url. This may take some time, please be patient."
if ! wget -P $data --no-check-certificate $full_url; then
echo "$0: error executing wget $full_url"
exit 1
fi
fi
if ! tar -C $data -xvzf $data/$part.tar.gz; then
echo "$0: error un-tarring archive $data/$part.tar.gz"
exit 1
fi
touch $data/LibriSpeech/$part/.complete
echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
if $remove_archive; then
echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
rm $data/$part.tar.gz
fi

View File

@@ -0,0 +1,51 @@
import argparse
import logging
import glob
import os
from tqdm import tqdm
logger = logging.getLogger()
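# Walk a LibriTTS-style directory (<src_dir>/<spk>/<chapter>/<utt>.wav with a matching
# .normalized.txt transcript) and write Kaldi-style wav.scp/text/utt2spk/spk2utt files to --des_dir.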
def main():
wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
for wav in tqdm(wavs):
txt = wav.replace('.wav', '.normalized.txt')
if not os.path.exists(txt):
            logger.warning('{} does not exist'.format(txt))
continue
with open(txt) as f:
            content = ''.join(l.replace('\n', '') for l in f.readlines())
utt = os.path.basename(wav).replace('.wav', '')
spk = utt.split('_')[0]
utt2wav[utt] = wav
utt2text[utt] = content
utt2spk[utt] = spk
if spk not in spk2utt:
spk2utt[spk] = []
spk2utt[spk].append(utt)
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
for k, v in utt2wav.items():
f.write('{} {}\n'.format(k, v))
with open('{}/text'.format(args.des_dir), 'w') as f:
for k, v in utt2text.items():
f.write('{} {}\n'.format(k, v))
with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
for k, v in utt2spk.items():
f.write('{} {}\n'.format(k, v))
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
for k, v in spk2utt.items():
f.write('{} {}\n'.format(k, ' '.join(v)))
return
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
type=str)
args = parser.parse_args()
main()

View File

@@ -0,0 +1,3 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:../../../third_party/AcademiCodec:../../../third_party/Matcha-TTS:$PYTHONPATH

View File

@@ -0,0 +1,105 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# inference
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
for mode in sft zero_shot; do
python cosyvoice/bin/inference.py --mode $mode \
--gpu 0 \
--config conf/cosyvoice.yaml \
--prompt_data data/test-clean/parquet/data.list \
--prompt_utt2data data/test-clean/parquet/utt2data.list \
--tts_text `pwd`/tts_text.json \
--llm_model $pretrained_model_dir/llm.pt \
--flow_model $pretrained_model_dir/flow.pt \
--hifigan_model $pretrained_model_dir/hift.pt \
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
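# set train_engine=deepspeed to train with DeepSpeed; its optimizer comes from conf/ds_stage2.json
# (see the notice below), while torch_ddp uses the optim/scheduler defined in conf/cosyvoice.yaml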
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
for model in llm; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi

View File

@@ -0,0 +1 @@
../../../tools

View File

@@ -0,0 +1,5 @@
{
"1089_134686_000002_000000": [
"hello, my name is Jack. What is your name?"
]
}

27
requirements.txt Normal file
View File

@@ -0,0 +1,27 @@
--extra-index-url https://download.pytorch.org/whl/cu118
conformer==0.3.2
deepspeed==0.14.2
diffusers==0.27.2
gdown==5.1.0
gradio==4.32.2
grpcio==1.57.0
grpcio-tools==1.57.0
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==6.0.2
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnxruntime-gpu==1.16.0
openai-whisper==20231117
protobuf==4.25
pydantic==2.7.0
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
torch==2.0.1
torchaudio==2.0.2
wget==3.2

15
runtime/python/Dockerfile Normal file
View File

@@ -0,0 +1,15 @@
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /opt/CosyVoice
RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
RUN apt-get update -y
RUN apt-get -y install python3-dev cmake python3-pip git
# installing torch takes a long time, cache this layer in case we change requirements.txt
# RUN git clone --depth 1 https://github.com/FunAudioLLM/CosyVoice.git
ADD CosyVoice.tar .
RUN mv CosyVoice_dockerfile CosyVoice
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
RUN cd CosyVoice/runtime/python && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
CMD ["/bin/bash", "-c", "cd /opt/CosyVoice/CosyVoice/runtime/python && . ./path/sh && python3 server.py --port 50000 --max_conc 4 --model_dir speech_tts/CosyVoice-300M && sleep infinity"]

103
runtime/python/client.py Normal file
View File

@@ -0,0 +1,103 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
import logging
import argparse
import torchaudio
import cosyvoice_pb2
import cosyvoice_pb2_grpc
import grpc
import torch
import numpy as np
from cosyvoice.utils.file_utils import load_wav
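# Minimal gRPC client: build one of the four request types (sft / zero_shot / cross_lingual /
# instruct), send it to the CosyVoice server and save the returned int16 PCM audio to --tts_wav.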
def main():
with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
request = cosyvoice_pb2.Request()
if args.mode == 'sft':
logging.info('send sft request')
sft_request = cosyvoice_pb2.sftRequest()
sft_request.spk_id = args.spk_id
sft_request.tts_text = args.tts_text
request.sft_request.CopyFrom(sft_request)
elif args.mode == 'zero_shot':
logging.info('send zero_shot request')
zero_shot_request = cosyvoice_pb2.zeroshotRequest()
zero_shot_request.tts_text = args.tts_text
zero_shot_request.prompt_text = args.prompt_text
prompt_speech = load_wav(args.prompt_wav, 16000)
zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
request.zero_shot_request.CopyFrom(zero_shot_request)
elif args.mode == 'cross_lingual':
logging.info('send cross_lingual request')
cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
cross_lingual_request.tts_text = args.tts_text
prompt_speech = load_wav(args.prompt_wav, 16000)
cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
request.cross_lingual_request.CopyFrom(cross_lingual_request)
else:
logging.info('send instruct request')
instruct_request = cosyvoice_pb2.instructRequest()
instruct_request.tts_text = args.tts_text
instruct_request.spk_id = args.spk_id
instruct_request.instruct_text = args.instruct_text
request.instruct_request.CopyFrom(instruct_request)
response = stub.Inference(request)
logging.info('save response to {}'.format(args.tts_wav))
tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
torchaudio.save(args.tts_wav, tts_speech, target_sr)
logging.info('get response')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--host',
type=str,
default='0.0.0.0')
parser.add_argument('--port',
type=int,
                        default=50000)
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
help='request mode')
parser.add_argument('--tts_text',
type=str,
default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
parser.add_argument('--spk_id',
type=str,
default='中文女')
parser.add_argument('--prompt_text',
type=str,
default='希望你以后能够做的比我还好呦。')
parser.add_argument('--prompt_wav',
type=str,
default='../../zero_shot_prompt.wav')
parser.add_argument('--instruct_text',
type=str,
default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
parser.add_argument('--tts_wav',
type=str,
default='demo.wav')
args = parser.parse_args()
prompt_sr, target_sr = 16000, 22050
main()

View File

@@ -0,0 +1,56 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
syntax = "proto3";
package cosyvoice;
option go_package = "protos/";
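// The service exposes a single Inference RPC: the Request oneof selects one of the
// sft / zero_shot / cross_lingual / instruct payloads, and Response.tts_audio carries the
// synthesized speech (the Python server encodes it as raw int16 PCM bytes).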
service CosyVoice{
rpc Inference(Request) returns (Response) {}
}
message Request{
oneof RequestPayload {
sftRequest sft_request = 1;
zeroshotRequest zero_shot_request = 2;
crosslingualRequest cross_lingual_request = 3;
instructRequest instruct_request = 4;
}
}
message sftRequest{
string spk_id = 1;
string tts_text = 2;
}
message zeroshotRequest{
string tts_text = 1;
string prompt_text = 2;
bytes prompt_audio = 3;
}
message crosslingualRequest{
string tts_text = 1;
bytes prompt_audio = 2;
}
message instructRequest{
string tts_text = 1;
string spk_id = 2;
string instruct_text = 3;
}
message Response{
bytes tts_audio = 1;
}

3
runtime/python/path.sh Normal file
View File

@@ -0,0 +1,3 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../:../../third_party/AcademiCodec:../../third_party/Matcha-TTS:$PYTHONPATH

85
runtime/python/server.py Normal file
View File

@@ -0,0 +1,85 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from concurrent import futures
import argparse
import cosyvoice_pb2
import cosyvoice_pb2_grpc
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import grpc
import torch
import numpy as np
from cosyvoice.cli.cosyvoice import CosyVoice
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
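# gRPC service wrapping cosyvoice.cli.cosyvoice.CosyVoice: each request type is dispatched to the
# matching inference_* call; prompt audio arrives as 16kHz int16 PCM bytes and is rescaled to float,
# and the synthesized speech is returned as int16 PCM bytes in Response.tts_audio.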
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
def __init__(self, args):
self.cosyvoice = CosyVoice(args.model_dir)
logging.info('grpc service initialized')
def Inference(self, request, context):
if request.HasField('sft_request'):
logging.info('get sft inference request')
model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
elif request.HasField('zero_shot_request'):
logging.info('get zero_shot inference request')
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
elif request.HasField('cross_lingual_request'):
logging.info('get cross_lingual inference request')
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
else:
logging.info('get instruct inference request')
model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
logging.info('send inference response')
response = cosyvoice_pb2.Response()
response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
return response
def main():
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
grpcServer.start()
logging.info("server listening on 0.0.0.0:{}".format(args.port))
grpcServer.wait_for_termination()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port',
type=int,
default=50000)
parser.add_argument('--max_conc',
type=int,
default=4)
parser.add_argument('--model_dir',
type=str,
required=True,
default='speech_tts/CosyVoice-300M',
help='local path or modelscope repo id')
args = parser.parse_args()
main()

67
tools/extract_embedding.py Executable file
View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torchaudio
from tqdm import tqdm
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
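# For each utterance in <dir>/wav.scp: resample to 16kHz, compute 80-dim mean-normalized fbank
# features and run the campplus ONNX model to get a speaker embedding; per-utterance embeddings are
# saved to utt2embedding.pt and per-speaker embedding lists to spk2embedding.pt.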
def main(args):
utt2wav, utt2spk = {}, {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
with open('{}/utt2spk'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2embedding, spk2embedding = {}, {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
feat = kaldi.fbank(audio,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
utt2embedding[utt] = embedding
spk = utt2spk[utt]
if spk not in spk2embedding:
spk2embedding[spk] = []
spk2embedding[spk].append(embedding)
torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)

61
tools/extract_speech_token.py Executable file
View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from tqdm import tqdm
import onnxruntime
import numpy as np
import torchaudio
import whisper
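# For each utterance in <dir>/wav.scp: resample to 16kHz, compute whisper 128-bin log-mel features
# and run the speech tokenizer ONNX model to get discrete speech tokens (audio longer than 30s gets
# an empty token list); results are saved to utt2speech_token.pt.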
def main(args):
utt2wav = {}
with open('{}/wav.scp'.format(args.dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ["CUDAExecutionProvider"]
ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
utt2speech_token = {}
for utt in tqdm(utt2wav.keys()):
audio, sample_rate = torchaudio.load(utt2wav[utt])
if sample_rate != 16000:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
if audio.shape[1] / 16000 > 30:
            logging.warning('do not support extracting speech tokens for audio longer than 30s')
speech_token = []
else:
feat = whisper.log_mel_spectrogram(audio, n_mels=128)
speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
utt2speech_token[utt] = speech_token
torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--dir',
type=str)
parser.add_argument('--onnx_path',
type=str)
args = parser.parse_args()
main(args)

112
tools/make_parquet_list.py Executable file
View File

@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import json
from tqdm import tqdm
import pandas as pd
import multiprocessing
import time
import torch
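# Pack wav bytes, text, speaker, embeddings and speech tokens into parquet files of
# num_utts_per_parquet utterances each (written by a process pool), plus utt2parquet/spk2parquet
# json mappings and data.list/utt2data.list/spk2data.list indexes in --des_dir.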
def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
start_time = time.time()
data_list = []
for utt in tqdm(utt_list):
data = open(utt2wav[utt], 'rb').read()
data_list.append(data)
wav_list = [utt2wav[utt] for utt in utt_list]
text_list = [utt2text[utt] for utt in utt_list]
spk_list = [utt2spk[utt] for utt in utt_list]
uttembedding_list = [utt2embedding[utt] for utt in utt_list]
spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
speech_token_list = [utt2speech_token[utt] for utt in utt_list]
    # save data to parquet_file, and the utt/spk mappings to utt2parquet_file/spk2parquet_file
df = pd.DataFrame()
df['utt'] = utt_list
df['wav'] = wav_list
df['audio_data'] = data_list
df['text'] = text_list
df['spk'] = spk_list
df['utt_embedding'] = uttembedding_list
df['spk_embedding'] = spkembedding_list
df['speech_token'] = speech_token_list
df.to_parquet(parquet_file)
with open(utt2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
with open(spk2parquet_file, 'w') as f:
json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
logging.info('spend time {}'.format(time.time() - start_time))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_utts_per_parquet',
type=int,
default=1000,
help='num utts per parquet')
parser.add_argument('--num_processes',
type=int,
default=1,
help='num processes for make parquets')
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
type=str)
args = parser.parse_args()
utt2wav, utt2text, utt2spk = {}, {}, {}
with open('{}/wav.scp'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2wav[l[0]] = l[1]
with open('{}/text'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2text[l[0]] = ' '.join(l[1:])
with open('{}/utt2spk'.format(args.src_dir)) as f:
for l in f:
l = l.replace('\n', '').split()
utt2spk[l[0]] = l[1]
utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
utts = list(utt2wav.keys())
# Using process pool to speedup
pool = multiprocessing.Pool(processes=args.num_processes)
parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
parquet_list.append(parquet_file)
utt2parquet_list.append(utt2parquet_file)
spk2parquet_list.append(spk2parquet_file)
pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
pool.close()
pool.join()
with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
for name in parquet_list:
f1.write(name + '\n')
for name in utt2parquet_list:
f2.write(name + '\n')
for name in spk2parquet_list:
f3.write(name + '\n')

186
webui.py Normal file
View File

@@ -0,0 +1,186 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/third_party/AcademiCodec'.format(ROOT_DIR))
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
import argparse
import gradio as gr
import numpy as np
import torch
import torchaudio
import random
import librosa
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
def generate_seed():
seed = random.randint(1, 100000000)
return {
"__type__": "update",
"value": seed
}
def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
max_val = 0.8
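# Trim leading/trailing silence from the prompt audio, normalize its peak to max_val and append
# 0.2s of silence before it is passed to the model.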
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
speech, _ = librosa.effects.trim(
speech, top_db=top_db,
frame_length=win_length,
hop_length=hop_length
)
if speech.abs().max() > max_val:
speech = speech / speech.abs().max() * max_val
speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
return speech
inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2.点击生成音频按钮',
'3s极速复刻': '1. 选择prompt音频文件或录入prompt音频若同时提供优先选择prompt音频文件\n2. 输入prompt文本\n3.点击生成音频按钮',
'跨语种复刻': '1. 选择prompt音频文件或录入prompt音频若同时提供优先选择prompt音频文件\n2.点击生成音频按钮',
'自然语言控制': '1. 输入instruct文本\n2.点击生成音频按钮'}
def change_instruction(mode_checkbox_group):
return instruct_dict[mode_checkbox_group]
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed):
if prompt_wav_upload is not None:
prompt_wav = prompt_wav_upload
elif prompt_wav_record is not None:
prompt_wav = prompt_wav_record
else:
prompt_wav = None
    # if in instruct mode, make sure that the model is speech_tts/CosyVoice-300M-Instruct and the mode is not cross_lingual
if mode_checkbox_group in ['自然语言控制']:
if cosyvoice.frontend.instruct is False:
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用speech_tts/CosyVoice-300M-Instruct模型'.format(args.model_dir))
return (target_sr, default_data)
if instruct_text == '':
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
return (target_sr, default_data)
if prompt_wav is not None or prompt_text != '':
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
    # if in cross_lingual mode, make sure that the model is speech_tts/CosyVoice-300M and that tts_text and the prompt are in different languages
if mode_checkbox_group in ['跨语种复刻']:
if cosyvoice.frontend.instruct is True:
gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用speech_tts/CosyVoice-300M模型'.format(args.model_dir))
return (target_sr, default_data)
if instruct_text != '':
gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
if prompt_wav is None:
gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
return (target_sr, default_data)
gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
    # if in zero_shot or cross_lingual mode, make sure that prompt_text and prompt_wav meet the requirements
if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
if prompt_wav is None:
gr.Warning('prompt音频为空您是否忘记输入prompt音频')
return (target_sr, default_data)
if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
return (target_sr, default_data)
    # sft mode only uses sft_dropdown
if mode_checkbox_group in ['预训练音色']:
if instruct_text != '' or prompt_wav is not None or prompt_text != '':
gr.Info('您正在使用预训练音色模式prompt文本/prompt音频/instruct文本会被忽略')
    # zero_shot mode only uses prompt_wav and prompt_text
if mode_checkbox_group in ['3s极速复刻']:
if prompt_text == '':
gr.Warning('prompt文本为空您是否忘记输入prompt文本')
return (target_sr, default_data)
if instruct_text != '':
gr.Info('您正在使用3s极速复刻模式预训练音色/instruct文本会被忽略')
if mode_checkbox_group == '预训练音色':
logging.info('get sft inference request')
set_all_random_seed(seed)
output = cosyvoice.inference_sft(tts_text, sft_dropdown)
elif mode_checkbox_group == '3s极速复刻':
logging.info('get zero_shot inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
elif mode_checkbox_group == '跨语种复刻':
logging.info('get cross_lingual inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
else:
logging.info('get instruct inference request')
set_all_random_seed(seed)
output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
audio_data = output['tts_speech'].numpy().flatten()
return (target_sr, audio_data)
def main():
with gr.Blocks() as demo:
gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M) [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M-Instruct) [CosyVoice-300M-SFT](https://www.modelscope.cn/models/speech_tts/CosyVoice-300M-SFT)")
gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
with gr.Row():
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
with gr.Column(scale=0.25):
seed_button = gr.Button(value="\U0001F3B2")
seed = gr.Number(value=0, label="随机推理种子")
with gr.Row():
prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件注意采样率不低于16khz')
prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本需与prompt音频内容一致暂时不支持自动识别...", value='')
instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.", value='')
generate_button = gr.Button("生成音频")
audio_output = gr.Audio(label="合成音频")
seed_button.click(generate_seed, inputs=[], outputs=seed)
generate_button.click(generate_audio,
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed],
outputs=[audio_output])
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
demo.queue(max_size=4, default_concurrency_limit=2)
demo.launch(server_port=args.port)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port',
type=int,
default=8000)
parser.add_argument('--model_dir',
type=str,
default='speech_tts/CosyVoice-300M',
help='local path or modelscope repo id')
args = parser.parse_args()
cosyvoice = CosyVoice(args.model_dir)
sft_spk = cosyvoice.list_avaliable_spks()
prompt_sr, target_sr = 16000, 22050
default_data = np.zeros(target_sr)
main()

BIN
zero_shot_prompt.wav Normal file

Binary file not shown.