mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Compare commits
455 Commits
hengwu.zty
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f08872a82f | ||
|
|
f2ddcbe7f9 | ||
|
|
0d990d6074 | ||
|
|
c93d3dda01 | ||
|
|
84e41729ea | ||
|
|
f26cde56df | ||
|
|
66b80dbccb | ||
|
|
1822c5c908 | ||
|
|
1dcc59676f | ||
|
|
7fdd80dc64 | ||
|
|
f97d50d559 | ||
|
|
652132ebaa | ||
|
|
1ceb2b7b1e | ||
|
|
55e3e370a0 | ||
|
|
46788c7379 | ||
|
|
881177287c | ||
|
|
f88a14e41d | ||
|
|
dac6566fc3 | ||
|
|
cc91e40db8 | ||
|
|
ab7f1f4a86 | ||
|
|
e15222b17c | ||
|
|
cfa1c115b2 | ||
|
|
dd5cdb6ebf | ||
|
|
2d7ef0b719 | ||
|
|
ba5db602a9 | ||
|
|
5b94675f62 | ||
|
|
4c19646b9a | ||
|
|
63a06227d1 | ||
|
|
3b44913782 | ||
|
|
055f64d002 | ||
|
|
4d7295a9a7 | ||
|
|
8524c81acd | ||
|
|
a14e063ead | ||
|
|
2db78e7058 | ||
|
|
7538c6a73d | ||
|
|
823ae2c60d | ||
|
|
59cb2bf16c | ||
|
|
80bebb1978 | ||
|
|
bc34459bb8 | ||
|
|
9f27b42cd9 | ||
|
|
a7d6e2251a | ||
|
|
7baefaf0f2 | ||
|
|
ff0d05c380 | ||
|
|
f5816b4e51 | ||
|
|
8b54619760 | ||
|
|
2abd42220e | ||
|
|
2d6bb9bd80 | ||
|
|
0b80c0746a | ||
|
|
e98b828f33 | ||
|
|
4d4c787be0 | ||
|
|
781a49acb4 | ||
|
|
9476a063b3 | ||
|
|
3426ceb70f | ||
|
|
a460960ade | ||
|
|
f51f5c5c6a | ||
|
|
f11ba4024c | ||
|
|
089343ab0a | ||
|
|
0c50894d49 | ||
|
|
95d56cba64 | ||
|
|
095f7bad55 | ||
|
|
a6eb2c56da | ||
|
|
ca3b054a52 | ||
|
|
b02d7e61f7 | ||
|
|
6b6a5a7bd1 | ||
|
|
5640545406 | ||
|
|
5bc4b23f02 | ||
|
|
ebef63066f | ||
|
|
3298d6f3e3 | ||
|
|
f21c4764ec | ||
|
|
927addadd8 | ||
|
|
a051a09ba4 | ||
|
|
0c65d3c7ab | ||
|
|
56d9876037 | ||
|
|
b35ece675b | ||
|
|
59f02cb85d | ||
|
|
b4dd67a8af | ||
|
|
bfa835a74b | ||
|
|
622a3a19b0 | ||
|
|
d985100326 | ||
|
|
6816fc6a6f | ||
|
|
e8bf717333 | ||
|
|
fa2781405f | ||
|
|
cd26dd1932 | ||
|
|
6e01309e01 | ||
|
|
1fc8435146 | ||
|
|
a224be6117 | ||
|
|
33aee03ed5 | ||
|
|
8811e9f33a | ||
|
|
807bb6ee0b | ||
|
|
aceede59ba | ||
|
|
7cbd490253 | ||
|
|
a019a2504e | ||
|
|
f186ec3338 | ||
|
|
988d395162 | ||
|
|
4d60ff6abc | ||
|
|
be005c825f | ||
|
|
79116ac32e | ||
|
|
31a0adc73d | ||
|
|
482464ea27 | ||
|
|
444b7ff5df | ||
|
|
b207c60885 | ||
|
|
0b357ba25d | ||
|
|
0867ebcb8c | ||
|
|
52556a6de9 | ||
|
|
66ef5a097b | ||
|
|
cc1991870b | ||
|
|
8ded65e611 | ||
|
|
6971536358 | ||
|
|
86e7c2d731 | ||
|
|
8a4309d89c | ||
|
|
ad257b06e3 | ||
|
|
633b991290 | ||
|
|
e04699c6da | ||
|
|
73d261dd48 | ||
|
|
b7ec6c4678 | ||
|
|
f76f5abcc1 | ||
|
|
6b5eef62cc | ||
|
|
dc96e4c984 | ||
|
|
70991d7327 | ||
|
|
8c96081f94 | ||
|
|
dd2d926147 | ||
|
|
da41f6175b | ||
|
|
e3c2400abb | ||
|
|
a976519ada | ||
|
|
cf615011ce | ||
|
|
9ddb9e4a83 | ||
|
|
0a496c18f7 | ||
|
|
05bdf4c769 | ||
|
|
1850e2a56e | ||
|
|
47e4137651 | ||
|
|
0bc48c1180 | ||
|
|
62d082634e | ||
|
|
07cbc51cd1 | ||
|
|
d1c354eac7 | ||
|
|
1b8d194b67 | ||
|
|
b44f121102 | ||
|
|
dc196df940 | ||
|
|
178da09993 | ||
|
|
11515d0d5a | ||
|
|
5427c274e3 | ||
|
|
3387f07266 | ||
|
|
b048a2d6db | ||
|
|
8555549e88 | ||
|
|
3047591fad | ||
|
|
5a00aefa20 | ||
|
|
35abb1f3b5 | ||
|
|
21a5efd8ae | ||
|
|
44316c3475 | ||
|
|
116c99bf39 | ||
|
|
6eaef42126 | ||
|
|
525531d8a3 | ||
|
|
c788bca1a6 | ||
|
|
ff9694dc2c | ||
|
|
4505924608 | ||
|
|
46dfe0439b | ||
|
|
63856565f3 | ||
|
|
cc234bd322 | ||
|
|
98ef35b1c0 | ||
|
|
fca1df3ea7 | ||
|
|
c939c80480 | ||
|
|
5bf6befd70 | ||
|
|
1c7976779b | ||
|
|
a8e1774e82 | ||
|
|
1f50ae259b | ||
|
|
793a7fe6ad | ||
|
|
79b2ea818b | ||
|
|
7d9d84d32d | ||
|
|
9b052a94c4 | ||
|
|
6dd68b9d5e | ||
|
|
9f55c5af8f | ||
|
|
b6c5f9dfd2 | ||
|
|
cbfed4a9ee | ||
|
|
54d21b40f0 | ||
|
|
c3250c222f | ||
|
|
82219cdd27 | ||
|
|
4159a18469 | ||
|
|
3e12bb86bd | ||
|
|
cbfbe2bc33 | ||
|
|
3660da4a19 | ||
|
|
3c921daede | ||
|
|
68100c267a | ||
|
|
fbab274b6a | ||
|
|
97f0bc61cd | ||
|
|
afb1a70f7a | ||
|
|
5c77e40304 | ||
|
|
b4c4d848ca | ||
|
|
88f467a8ac | ||
|
|
038ff9f353 | ||
|
|
65ad448714 | ||
|
|
a96ae13616 | ||
|
|
587604b2b4 | ||
|
|
e97cd1b655 | ||
|
|
8d67d17f73 | ||
|
|
a442317d17 | ||
|
|
7f8bea2669 | ||
|
|
6d876f573c | ||
|
|
3770c1c8b1 | ||
|
|
2c193781cc | ||
|
|
efe1d15960 | ||
|
|
9ebcf7b1ad | ||
|
|
37e48dd318 | ||
|
|
c07cd3d730 | ||
|
|
36aec2c0f7 | ||
|
|
d71d790f55 | ||
|
|
e1ffb1e978 | ||
|
|
9fea0f0836 | ||
|
|
9dc559fc2a | ||
|
|
634edfadf0 | ||
|
|
b56dfa223d | ||
|
|
f0b8e892f6 | ||
|
|
cfc68f379c | ||
|
|
4951d2ad1a | ||
|
|
d9ffd592f6 | ||
|
|
7902d1c17f | ||
|
|
39ffc50dec | ||
|
|
08312f4c46 | ||
|
|
c6d8737336 | ||
|
|
a22873e360 | ||
|
|
c97b445df4 | ||
|
|
265507f213 | ||
|
|
a69b7e275d | ||
|
|
fcc054f64e | ||
|
|
fd45708e4b | ||
|
|
296ed4f526 | ||
|
|
890300513c | ||
|
|
f77c6a85aa | ||
|
|
b6d66ce2e3 | ||
|
|
8e4f252d32 | ||
|
|
79b7dff8d2 | ||
|
|
95e99e0417 | ||
|
|
ba6d8c07ba | ||
|
|
2a3e033ee1 | ||
|
|
da3f129977 | ||
|
|
2889c25863 | ||
|
|
24f796a2b1 | ||
|
|
fd1a951a6c | ||
|
|
aa65200713 | ||
|
|
86e26f54c7 | ||
|
|
f1c214377c | ||
|
|
aea75207dd | ||
|
|
369ea80bd4 | ||
|
|
69518b2bde | ||
|
|
1c062ab381 | ||
|
|
276cfa02b6 | ||
|
|
190840b8dc | ||
|
|
c6c3f27ecc | ||
|
|
49761d2474 | ||
|
|
07e477519b | ||
|
|
41c5e8cd6d | ||
|
|
66ceaff472 | ||
|
|
07a314767f | ||
|
|
0b75c3a03f | ||
|
|
b4dea3d64a | ||
|
|
43f9e9ab20 | ||
|
|
025f6f0f7f | ||
|
|
69051d11ec | ||
|
|
59fa786769 | ||
|
|
f38f594303 | ||
|
|
eb4d5d053f | ||
|
|
d450c32296 | ||
|
|
e84d72a4d9 | ||
|
|
06e86619c2 | ||
|
|
e257c16796 | ||
|
|
87475ccf41 | ||
|
|
8a1bce6c81 | ||
|
|
b1e966309d | ||
|
|
b95f18909e | ||
|
|
1cfc5dd077 | ||
|
|
d2e43fe6f4 | ||
|
|
426c4001ca | ||
|
|
92f1c659b9 | ||
|
|
ac75ae5184 | ||
|
|
1e52c6071e | ||
|
|
2a0dd5447a | ||
|
|
b6a1116d15 | ||
|
|
5d12ced727 | ||
|
|
2ea414922a | ||
|
|
f0b5fbb658 | ||
|
|
0b17753abe | ||
|
|
9e156428e2 | ||
|
|
99ab0f4fcb | ||
|
|
77d8cf13a3 | ||
|
|
6b21f8e82c | ||
|
|
2745d47e92 | ||
|
|
737d10191b | ||
|
|
d3b1a8e352 | ||
|
|
88f6a8e2fa | ||
|
|
b9ddcba5fd | ||
|
|
1f30317247 | ||
|
|
bfcbc73df8 | ||
|
|
4d49b68207 | ||
|
|
5aa3a46d96 | ||
|
|
b60c37b31a | ||
|
|
3d0458af31 | ||
|
|
0f6ff298dd | ||
|
|
877cf1c873 | ||
|
|
dec008e1b7 | ||
|
|
178f4bbaf9 | ||
|
|
5627adefb1 | ||
|
|
d95aaea3c5 | ||
|
|
bd4be3fc05 | ||
|
|
7a969b10bb | ||
|
|
87cec23fd0 | ||
|
|
5b4ddd26fa | ||
|
|
0d1e562f1d | ||
|
|
b00d8a073c | ||
|
|
8a88446858 | ||
|
|
26c774098d | ||
|
|
81edc83648 | ||
|
|
60b0416229 | ||
|
|
32e6684025 | ||
|
|
8ec41faf91 | ||
|
|
6c93fe86c5 | ||
|
|
8266566144 | ||
|
|
091e5c4ed8 | ||
|
|
1298d90e48 | ||
|
|
bcc58cb4cb | ||
|
|
1d8d94de82 | ||
|
|
0993ec5f08 | ||
|
|
c4688b68eb | ||
|
|
d43a0171d4 | ||
|
|
c4c8050532 | ||
|
|
3581caec76 | ||
|
|
94d6ce1006 | ||
|
|
ac70560364 | ||
|
|
6b5931dc70 | ||
|
|
7c561b6a7f | ||
|
|
30851ede4b | ||
|
|
fc3ff075ec | ||
|
|
e982eecc27 | ||
|
|
66fbcf6ac2 | ||
|
|
07d23ab08b | ||
|
|
0bcd2318c7 | ||
|
|
c9b047fbda | ||
|
|
66ae73e409 | ||
|
|
1e853ed080 | ||
|
|
6b9cebea14 | ||
|
|
9b8f28aa32 | ||
|
|
2511a49a72 | ||
|
|
014fed4405 | ||
|
|
f56c2583e8 | ||
|
|
84015697c2 | ||
|
|
c693039d14 | ||
|
|
2345ce6be2 | ||
|
|
0bf706c26f | ||
|
|
3e381002d7 | ||
|
|
cde3cec6fa | ||
|
|
07352a50b3 | ||
|
|
dc3f6432ba | ||
|
|
d6dbdfbf31 | ||
|
|
c3dfd23399 | ||
|
|
7701325969 | ||
|
|
5ed5bb15c8 | ||
|
|
6d22d0b76f | ||
|
|
487701c98c | ||
|
|
3914b54c82 | ||
|
|
a2ece33477 | ||
|
|
3411e1f599 | ||
|
|
dfcd6d0a64 | ||
|
|
16d66dc6a6 | ||
|
|
d8f00f4793 | ||
|
|
0930b4a106 | ||
|
|
d554db7e32 | ||
|
|
027e1ccb82 | ||
|
|
d1f7c1c9d7 | ||
|
|
5bd5dfecab | ||
|
|
18b9a8c844 | ||
|
|
0f19b97c5a | ||
|
|
a4db3db8ed | ||
|
|
ace734def8 | ||
|
|
5157baf166 | ||
|
|
6b7286eb62 | ||
|
|
21ddaeccf2 | ||
|
|
7e6d60c24c | ||
|
|
de76577a9f | ||
|
|
29507bc77a | ||
|
|
555efd0301 | ||
|
|
ea7d709fbb | ||
|
|
73784974ce | ||
|
|
789ee9e5e7 | ||
|
|
cb200b21c5 | ||
|
|
8130abb5ea | ||
|
|
c9acce1482 | ||
|
|
d49259855b | ||
|
|
67f298d94a | ||
|
|
4a1ec98304 | ||
|
|
0b76dfa1eb | ||
|
|
74a449ad1f | ||
|
|
9c0aa1918b | ||
|
|
2c1877a5d4 | ||
|
|
abc6f70ace | ||
|
|
ffa28e3bbd | ||
|
|
8555ab4ded | ||
|
|
69202008ce | ||
|
|
ba3d9693da | ||
|
|
06934c38c7 | ||
|
|
d52358f6c5 | ||
|
|
72b89a52fb | ||
|
|
49015f63e6 | ||
|
|
ed87445540 | ||
|
|
f6d44af146 | ||
|
|
7a2014bee3 | ||
|
|
95051e5761 | ||
|
|
f65eca6723 | ||
|
|
cd26f11859 | ||
|
|
28f1353324 | ||
|
|
f6b5c42823 | ||
|
|
ff8e63567a | ||
|
|
2665b06e95 | ||
|
|
2898d5a851 | ||
|
|
e19e80fcd8 | ||
|
|
f517d3627a | ||
|
|
9e0b99e48e | ||
|
|
df653f1e98 | ||
|
|
c6b16c06e8 | ||
|
|
c901a12789 | ||
|
|
122df8c420 | ||
|
|
1d05ae5fd3 | ||
|
|
73271d46f9 | ||
|
|
7b3e285bca | ||
|
|
bcda6d807c | ||
|
|
4d6a55243c | ||
|
|
33a585374a | ||
|
|
90433f5373 | ||
|
|
eeebc45313 | ||
|
|
9100813b79 | ||
|
|
7f5e391041 | ||
|
|
e141634da1 | ||
|
|
11eacb810e | ||
|
|
7555afb90a | ||
|
|
2ce724045b | ||
|
|
7795445ed9 | ||
|
|
d8197de4cc | ||
|
|
752103a307 | ||
|
|
a801416805 | ||
|
|
fadb22086f | ||
|
|
d2dea3d928 | ||
|
|
ee988420f3 | ||
|
|
18599be8d5 | ||
|
|
29408360fb | ||
|
|
6e7f5b922a | ||
|
|
53a3c1b17f | ||
|
|
20e0715dac | ||
|
|
662012999a | ||
|
|
1ab3186799 | ||
|
|
5f21aef786 | ||
|
|
1d881df8b2 | ||
|
|
f1e374a9bb | ||
|
|
8b097f7625 | ||
|
|
dcc943db43 | ||
|
|
9ab298dd49 | ||
|
|
bb690d9d1e | ||
|
|
f4e70e222c | ||
|
|
02f941d348 | ||
|
|
a13411c561 |
56
.github/workflows/lint.yml
vendored
Normal file
56
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
name: Lint
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
push:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
quick-checks:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Fetch CosyVoice
|
||||||
|
uses: actions/checkout@v1
|
||||||
|
- name: Checkout PR tip
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
|
# We are on a PR, so actions/checkout leaves us on a merge commit.
|
||||||
|
# Check out the actual tip of the branch.
|
||||||
|
git checkout ${{ github.event.pull_request.head.sha }}
|
||||||
|
fi
|
||||||
|
echo ::set-output name=commit_sha::$(git rev-parse HEAD)
|
||||||
|
id: get_pr_tip
|
||||||
|
- name: Ensure no tabs
|
||||||
|
run: |
|
||||||
|
(! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
|
||||||
|
- name: Ensure no trailing whitespace
|
||||||
|
run: |
|
||||||
|
(! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))
|
||||||
|
|
||||||
|
flake8-py3:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v1
|
||||||
|
with:
|
||||||
|
python-version: 3.9
|
||||||
|
architecture: x64
|
||||||
|
- name: Fetch CosyVoice
|
||||||
|
uses: actions/checkout@v1
|
||||||
|
- name: Checkout PR tip
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
|
# We are on a PR, so actions/checkout leaves us on a merge commit.
|
||||||
|
# Check out the actual tip of the branch.
|
||||||
|
git checkout ${{ github.event.pull_request.head.sha }}
|
||||||
|
fi
|
||||||
|
echo ::set-output name=commit_sha::$(git rev-parse HEAD)
|
||||||
|
id: get_pr_tip
|
||||||
|
- name: Run flake8
|
||||||
|
run: |
|
||||||
|
set -eux
|
||||||
|
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
|
||||||
|
flake8 --version
|
||||||
|
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
|
||||||
|
if [ $? != 0 ]; then exit 1; fi
|
||||||
22
.github/workflows/stale-issues.yml
vendored
Normal file
22
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
name: Close inactive issues
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "30 1 * * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
close-issues:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v5
|
||||||
|
with:
|
||||||
|
days-before-issue-stale: 30
|
||||||
|
days-before-issue-close: 14
|
||||||
|
stale-issue-label: "stale"
|
||||||
|
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||||
|
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||||
|
days-before-pr-stale: -1
|
||||||
|
days-before-pr-close: -1
|
||||||
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -43,7 +43,10 @@ compile_commands.json
|
|||||||
|
|
||||||
# train/inference files
|
# train/inference files
|
||||||
*.wav
|
*.wav
|
||||||
|
*.m4a
|
||||||
|
*.aac
|
||||||
*.pt
|
*.pt
|
||||||
pretrained_models/*
|
pretrained_models/*
|
||||||
*_pb2_grpc.py
|
*_pb2_grpc.py
|
||||||
*_pb2.py
|
*_pb2.py
|
||||||
|
*.tar
|
||||||
283
README.md
283
README.md
@@ -1,144 +1,175 @@
|
|||||||
# CosyVoice
|

|
||||||
## 👉🏻 [CosyVoice Demos](https://fun-audio-llm.github.io/) 👈🏻
|
|
||||||
[[CosyVoice Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)][[CosyVoice Code](https://github.com/FunAudioLLM/CosyVoice)]
|
## 👉🏻 CosyVoice 👈🏻
|
||||||
|
|
||||||
|
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
|
||||||
|
|
||||||
|
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
|
||||||
|
|
||||||
|
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
|
||||||
|
|
||||||
|
## Highlight🔥
|
||||||
|
|
||||||
|
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
|
||||||
|
### Key Features
|
||||||
|
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
|
||||||
|
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
|
||||||
|
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
|
||||||
|
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
|
||||||
|
- **Bi-Streaming**: Support both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
|
||||||
|
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.
|
||||||
|
|
||||||
For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
|
|
||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
|
- [x] 2025/12
|
||||||
|
|
||||||
|
- [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
|
||||||
|
- [x] release Fun-CosyVoice3-0.5B modelscope gradio space
|
||||||
|
|
||||||
|
- [x] 2025/08
|
||||||
|
|
||||||
|
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
|
||||||
|
|
||||||
|
- [x] 2025/07
|
||||||
|
|
||||||
|
- [x] release Fun-CosyVoice 3.0 eval set
|
||||||
|
|
||||||
|
- [x] 2025/05
|
||||||
|
|
||||||
|
- [x] add CosyVoice2-0.5B vllm support
|
||||||
|
|
||||||
|
- [x] 2024/12
|
||||||
|
|
||||||
|
- [x] 25hz CosyVoice2-0.5B released
|
||||||
|
|
||||||
|
- [x] 2024/09
|
||||||
|
|
||||||
|
- [x] 25hz CosyVoice-300M base model
|
||||||
|
- [x] 25hz CosyVoice-300M voice conversion function
|
||||||
|
|
||||||
|
- [x] 2024/08
|
||||||
|
|
||||||
|
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
||||||
|
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
||||||
|
|
||||||
- [x] 2024/07
|
- [x] 2024/07
|
||||||
|
|
||||||
- [x] Flow matching training support
|
- [x] Flow matching training support
|
||||||
- [x] WeTextProcessing support when ttsfrd is not avaliable
|
- [x] WeTextProcessing support when ttsfrd is not available
|
||||||
- [x] Fastapi server and client
|
- [x] Fastapi server and client
|
||||||
|
|
||||||
- [ ] 2024/08
|
## Evaluation
|
||||||
|
|
||||||
- [ ] Repetition Aware Sampling(RAS) inference for llm stability
|
| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
|
||||||
- [ ] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||||
|
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
|
||||||
|
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
|
||||||
|
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
|
||||||
|
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
|
||||||
|
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
|
||||||
|
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
|
||||||
|
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
|
||||||
|
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
|
||||||
|
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
|
||||||
|
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
|
||||||
|
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
|
||||||
|
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
|
||||||
|
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
|
||||||
|
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
|
||||||
|
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
|
||||||
|
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
|
||||||
|
|
||||||
- [ ] 2024/09
|
|
||||||
|
|
||||||
- [ ] 50hz llm model which supports 10 language
|
|
||||||
|
|
||||||
- [ ] 2024/10
|
|
||||||
|
|
||||||
- [ ] 50hz llama based llm model which supports lora finetune
|
|
||||||
|
|
||||||
- [ ] TBD
|
|
||||||
|
|
||||||
- [ ] Support more instruction mode
|
|
||||||
- [ ] Voice conversion
|
|
||||||
- [ ] Music generation
|
|
||||||
- [ ] Training script sample based on Mandarin
|
|
||||||
- [ ] CosyVoice-500M trained with more multi-lingual data
|
|
||||||
- [ ] More...
|
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
**Clone and install**
|
### Clone and install
|
||||||
|
|
||||||
- Clone the repo
|
- Clone the repo
|
||||||
``` sh
|
``` sh
|
||||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||||
# If you failed to clone submodule due to network failures, please run following command until success
|
# If you failed to clone the submodule due to network failures, please run the following command until success
|
||||||
cd CosyVoice
|
cd CosyVoice
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
```
|
```
|
||||||
|
|
||||||
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
||||||
- Create Conda env:
|
- Create Conda env:
|
||||||
|
|
||||||
``` sh
|
``` sh
|
||||||
conda create -n cosyvoice python=3.8
|
conda create -n cosyvoice -y python=3.10
|
||||||
conda activate cosyvoice
|
conda activate cosyvoice
|
||||||
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
|
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||||
conda install -y -c conda-forge pynini==2.1.5
|
|
||||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
|
||||||
|
|
||||||
# If you encounter sox compatibility issues
|
# If you encounter sox compatibility issues
|
||||||
# ubuntu
|
# ubuntu
|
||||||
sudo apt-get install sox libsox-dev
|
sudo apt-get install sox libsox-dev
|
||||||
# centos
|
# centos
|
||||||
sudo yum install sox sox-devel
|
sudo yum install sox sox-devel
|
||||||
```
|
```
|
||||||
|
|
||||||
**Model download**
|
### Model download
|
||||||
|
|
||||||
We strongly recommend that you download our pretrained `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
||||||
|
|
||||||
If you are expert in this field, and you are only interested in training your own CosyVoice model from scratch, you can skip this step.
|
|
||||||
|
|
||||||
``` python
|
``` python
|
||||||
# SDK模型下载
|
# modelscope SDK model download
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||||
|
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
||||||
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
||||||
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||||
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||||
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
||||||
|
|
||||||
|
# for oversea users, huggingface SDK model download
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||||
|
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
||||||
|
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
||||||
|
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||||
|
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||||
|
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
||||||
```
|
```
|
||||||
|
|
||||||
``` sh
|
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
|
||||||
# git模型下载,请确保已安装git lfs
|
|
||||||
mkdir -p pretrained_models
|
|
||||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
|
|
||||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
|
|
||||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
|
|
||||||
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
|
||||||
```
|
|
||||||
|
|
||||||
Optionaly, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
|
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
|
||||||
|
|
||||||
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
|
||||||
|
|
||||||
``` sh
|
``` sh
|
||||||
cd pretrained_models/CosyVoice-ttsfrd/
|
cd pretrained_models/CosyVoice-ttsfrd/
|
||||||
unzip resource.zip -d .
|
unzip resource.zip -d .
|
||||||
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
|
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
||||||
|
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
**Basic Usage**
|
### Basic Usage
|
||||||
|
|
||||||
For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
|
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
|
||||||
For sft inference, please use `CosyVoice-300M-SFT` model.
|
Follow the code in `example.py` for detailed usage of each model.
|
||||||
For instruct inference, please use `CosyVoice-300M-Instruct` model.
|
```sh
|
||||||
First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
|
python example.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### vLLM Usage
|
||||||
|
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
|
||||||
|
Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
|
||||||
|
|
||||||
|
Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
|
||||||
|
|
||||||
``` sh
|
``` sh
|
||||||
export PYTHONPATH=third_party/Matcha-TTS
|
conda create -n cosyvoice_vllm --clone cosyvoice
|
||||||
|
conda activate cosyvoice_vllm
|
||||||
|
# for vllm==0.9.0
|
||||||
|
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||||
|
# for vllm>=0.11.0
|
||||||
|
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||||
|
python vllm_example.py
|
||||||
```
|
```
|
||||||
|
|
||||||
``` python
|
#### Start web demo
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
|
||||||
from cosyvoice.utils.file_utils import load_wav
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
|
|
||||||
# sft usage
|
|
||||||
print(cosyvoice.list_avaliable_spks())
|
|
||||||
output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
|
|
||||||
torchaudio.save('sft.wav', output['tts_speech'], 22050)
|
|
||||||
|
|
||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
|
|
||||||
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
|
||||||
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
|
||||||
output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
|
|
||||||
torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
|
|
||||||
# cross_lingual usage
|
|
||||||
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
|
||||||
output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
|
|
||||||
torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
|
|
||||||
|
|
||||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
|
||||||
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
|
||||||
output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
|
|
||||||
torchaudio.save('instruct.wav', output['tts_speech'], 22050)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Start web demo**
|
|
||||||
|
|
||||||
You can use our web demo page to get familiar with CosyVoice quickly.
|
You can use our web demo page to get familiar with CosyVoice quickly.
|
||||||
We support sft/zero_shot/cross_lingual/instruct inference in web demo.
|
|
||||||
|
|
||||||
Please see the demo website for details.
|
Please see the demo website for details.
|
||||||
|
|
||||||
@@ -147,15 +178,14 @@ Please see the demo website for details.
|
|||||||
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
||||||
```
|
```
|
||||||
|
|
||||||
**Advanced Usage**
|
#### Advanced Usage
|
||||||
|
|
||||||
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
For advanced users, we have provided training and inference scripts in `examples/libritts`.
|
||||||
You can get familiar with CosyVoice following this recipie.
|
|
||||||
|
|
||||||
**Build for deployment**
|
#### Build for deployment
|
||||||
|
|
||||||
Optionally, if you want to use grpc for service deployment,
|
Optionally, if you want service deployment,
|
||||||
you can run following steps. Otherwise, you can just ignore this step.
|
You can run the following steps.
|
||||||
|
|
||||||
``` sh
|
``` sh
|
||||||
cd runtime/python
|
cd runtime/python
|
||||||
@@ -163,12 +193,23 @@ docker build -t cosyvoice:v1.0 .
|
|||||||
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
||||||
# for grpc usage
|
# for grpc usage
|
||||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||||
python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||||
# for fastapi usage
|
# for fastapi usage
|
||||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
|
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||||
python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using Nvidia TensorRT-LLM for deployment
|
||||||
|
|
||||||
|
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
|
||||||
|
To quick start:
|
||||||
|
|
||||||
|
``` sh
|
||||||
|
cd runtime/triton_trtllm
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
|
||||||
|
|
||||||
## Discussion & Communication
|
## Discussion & Communication
|
||||||
|
|
||||||
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
||||||
@@ -185,5 +226,39 @@ You can also scan the QR code to join our official Dingding chat group.
|
|||||||
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
||||||
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
||||||
|
|
||||||
|
## Citations
|
||||||
|
|
||||||
|
``` bibtex
|
||||||
|
@article{du2024cosyvoice,
|
||||||
|
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
|
||||||
|
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
|
||||||
|
journal={arXiv preprint arXiv:2407.05407},
|
||||||
|
year={2024}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{du2024cosyvoice,
|
||||||
|
title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
|
||||||
|
author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
|
||||||
|
journal={arXiv preprint arXiv:2412.10117},
|
||||||
|
year={2024}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{du2025cosyvoice,
|
||||||
|
title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
|
||||||
|
author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
|
||||||
|
journal={arXiv preprint arXiv:2505.17589},
|
||||||
|
year={2025}
|
||||||
|
}
|
||||||
|
|
||||||
|
@inproceedings{lyu2025build,
|
||||||
|
title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
|
||||||
|
author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
|
||||||
|
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
||||||
|
pages={1--2},
|
||||||
|
year={2025},
|
||||||
|
organization={IEEE}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Disclaimer
|
## Disclaimer
|
||||||
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 94 KiB After Width: | Height: | Size: 120 KiB |
BIN
asset/zero_shot_prompt.wav
Normal file
BIN
asset/zero_shot_prompt.wav
Normal file
Binary file not shown.
93
cosyvoice/bin/average_model.py
Normal file
93
cosyvoice/bin/average_model.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
||||||
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='average model')
|
||||||
|
parser.add_argument('--dst_model', required=True, help='averaged model')
|
||||||
|
parser.add_argument('--src_path',
|
||||||
|
required=True,
|
||||||
|
help='src model path for average')
|
||||||
|
parser.add_argument('--val_best',
|
||||||
|
action="store_true",
|
||||||
|
help='averaged model')
|
||||||
|
parser.add_argument('--num',
|
||||||
|
default=5,
|
||||||
|
type=int,
|
||||||
|
help='nums for averaged model')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
val_scores = []
|
||||||
|
if args.val_best:
|
||||||
|
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
||||||
|
yamls = [
|
||||||
|
f for f in yamls
|
||||||
|
if not (os.path.basename(f).startswith('train')
|
||||||
|
or os.path.basename(f).startswith('init'))
|
||||||
|
]
|
||||||
|
for y in yamls:
|
||||||
|
with open(y, 'r') as f:
|
||||||
|
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
||||||
|
loss = float(dic_yaml['loss_dict']['loss'])
|
||||||
|
epoch = int(dic_yaml['epoch'])
|
||||||
|
step = int(dic_yaml['step'])
|
||||||
|
tag = dic_yaml['tag']
|
||||||
|
val_scores += [[epoch, step, loss, tag]]
|
||||||
|
sorted_val_scores = sorted(val_scores,
|
||||||
|
key=lambda x: x[2],
|
||||||
|
reverse=False)
|
||||||
|
print("best val (epoch, step, loss, tag) = " +
|
||||||
|
str(sorted_val_scores[:args.num]))
|
||||||
|
path_list = [
|
||||||
|
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
||||||
|
for score in sorted_val_scores[:args.num]
|
||||||
|
]
|
||||||
|
print(path_list)
|
||||||
|
avg = {}
|
||||||
|
num = args.num
|
||||||
|
assert num == len(path_list)
|
||||||
|
for path in path_list:
|
||||||
|
print('Processing {}'.format(path))
|
||||||
|
states = torch.load(path, map_location=torch.device('cpu'))
|
||||||
|
for k in states.keys():
|
||||||
|
if k not in ['step', 'epoch']:
|
||||||
|
if k not in avg.keys():
|
||||||
|
avg[k] = states[k].clone()
|
||||||
|
else:
|
||||||
|
avg[k] += states[k]
|
||||||
|
# average
|
||||||
|
for k in avg.keys():
|
||||||
|
if avg[k] is not None:
|
||||||
|
# pytorch 1.6 use true_divide instead of /=
|
||||||
|
avg[k] = torch.true_divide(avg[k], num)
|
||||||
|
print('Saving to {}'.format(args.dst_model))
|
||||||
|
torch.save(avg, args.dst_model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
99
cosyvoice/bin/export_jit.py
Normal file
99
cosyvoice/bin/export_jit.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||||
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
|
from cosyvoice.cli.cosyvoice import AutoModel
|
||||||
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||||
|
parser.add_argument('--model_dir',
|
||||||
|
type=str,
|
||||||
|
default='pretrained_models/CosyVoice-300M',
|
||||||
|
help='local path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimized_script(model, preserved_attrs=[]):
|
||||||
|
script = torch.jit.script(model)
|
||||||
|
if preserved_attrs != []:
|
||||||
|
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
||||||
|
else:
|
||||||
|
script = torch.jit.freeze(script)
|
||||||
|
script = torch.jit.optimize_for_inference(script)
|
||||||
|
return script
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
torch._C._jit_set_profiling_executor(False)
|
||||||
|
|
||||||
|
model = AutoModel(model_dir=args.model_dir)
|
||||||
|
|
||||||
|
if model.__class__.__name__ == 'CosyVoice':
|
||||||
|
# 1. export llm text_encoder
|
||||||
|
llm_text_encoder = model.model.llm.text_encoder
|
||||||
|
script = get_optimized_script(llm_text_encoder)
|
||||||
|
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
||||||
|
script = get_optimized_script(llm_text_encoder.half())
|
||||||
|
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||||
|
logging.info('successfully export llm_text_encoder')
|
||||||
|
|
||||||
|
# 2. export llm llm
|
||||||
|
llm_llm = model.model.llm.llm
|
||||||
|
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
||||||
|
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
||||||
|
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
||||||
|
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||||
|
logging.info('successfully export llm_llm')
|
||||||
|
|
||||||
|
# 3. export flow encoder
|
||||||
|
flow_encoder = model.model.flow.encoder
|
||||||
|
script = get_optimized_script(flow_encoder)
|
||||||
|
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||||
|
script = get_optimized_script(flow_encoder.half())
|
||||||
|
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||||
|
logging.info('successfully export flow_encoder')
|
||||||
|
elif model.__class__.__name__ == 'CosyVoice2':
|
||||||
|
# 1. export flow encoder
|
||||||
|
flow_encoder = model.model.flow.encoder
|
||||||
|
script = get_optimized_script(flow_encoder)
|
||||||
|
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||||
|
script = get_optimized_script(flow_encoder.half())
|
||||||
|
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||||
|
logging.info('successfully export flow_encoder')
|
||||||
|
else:
|
||||||
|
raise ValueError('unsupported model type')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
114
cosyvoice/bin/export_onnx.py
Normal file
114
cosyvoice/bin/export_onnx.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
||||||
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import onnxruntime
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||||
|
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
|
from cosyvoice.cli.cosyvoice import AutoModel
|
||||||
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
||||||
|
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||||
|
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
||||||
|
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||||
|
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
||||||
|
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
||||||
|
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||||
|
return x, mask, mu, t, spks, cond
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||||
|
parser.add_argument('--model_dir',
|
||||||
|
type=str,
|
||||||
|
default='pretrained_models/CosyVoice-300M',
|
||||||
|
help='local path')
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
model = AutoModel(model_dir=args.model_dir)
|
||||||
|
|
||||||
|
# 1. export flow decoder estimator
|
||||||
|
estimator = model.model.flow.decoder.estimator
|
||||||
|
estimator.eval()
|
||||||
|
|
||||||
|
device = model.model.device
|
||||||
|
batch_size, seq_len = 2, 256
|
||||||
|
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||||
|
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||||
|
torch.onnx.export(
|
||||||
|
estimator,
|
||||||
|
(x, mask, mu, t, spks, cond),
|
||||||
|
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||||
|
export_params=True,
|
||||||
|
opset_version=18,
|
||||||
|
do_constant_folding=True,
|
||||||
|
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||||
|
output_names=['estimator_out'],
|
||||||
|
dynamic_axes={
|
||||||
|
'x': {2: 'seq_len'},
|
||||||
|
'mask': {2: 'seq_len'},
|
||||||
|
'mu': {2: 'seq_len'},
|
||||||
|
'cond': {2: 'seq_len'},
|
||||||
|
'estimator_out': {2: 'seq_len'},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. test computation consistency
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||||
|
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||||
|
sess_options=option, providers=providers)
|
||||||
|
|
||||||
|
for _ in tqdm(range(10)):
|
||||||
|
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||||
|
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||||
|
ort_inputs = {
|
||||||
|
'x': x.cpu().numpy(),
|
||||||
|
'mask': mask.cpu().numpy(),
|
||||||
|
'mu': mu.cpu().numpy(),
|
||||||
|
't': t.cpu().numpy(),
|
||||||
|
'spks': spks.cpu().numpy(),
|
||||||
|
'cond': cond.cpu().numpy()
|
||||||
|
}
|
||||||
|
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||||
|
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||||
|
logging.info('successfully export estimator')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
import torchaudio
|
|
||||||
from hyperpyyaml import load_hyperpyyaml
|
|
||||||
from tqdm import tqdm
|
|
||||||
from cosyvoice.cli.model import CosyVoiceModel
|
|
||||||
|
|
||||||
from cosyvoice.dataset.dataset import Dataset
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='inference with your model')
|
|
||||||
parser.add_argument('--config', required=True, help='config file')
|
|
||||||
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
|
||||||
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
|
||||||
parser.add_argument('--tts_text', required=True, help='tts input file')
|
|
||||||
parser.add_argument('--llm_model', required=True, help='llm model file')
|
|
||||||
parser.add_argument('--flow_model', required=True, help='flow model file')
|
|
||||||
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
|
||||||
parser.add_argument('--gpu',
|
|
||||||
type=int,
|
|
||||||
default=-1,
|
|
||||||
help='gpu id for this rank, -1 for cpu')
|
|
||||||
parser.add_argument('--mode',
|
|
||||||
default='sft',
|
|
||||||
choices=['sft', 'zero_shot'],
|
|
||||||
help='inference mode')
|
|
||||||
parser.add_argument('--result_dir', required=True, help='asr result file')
|
|
||||||
args = parser.parse_args()
|
|
||||||
print(args)
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = get_args()
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
|
||||||
format='%(asctime)s %(levelname)s %(message)s')
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
|
||||||
|
|
||||||
# Init cosyvoice models from configs
|
|
||||||
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
|
||||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
|
||||||
with open(args.config, 'r') as f:
|
|
||||||
configs = load_hyperpyyaml(f)
|
|
||||||
|
|
||||||
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
|
||||||
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
|
||||||
|
|
||||||
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
|
||||||
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
|
||||||
|
|
||||||
del configs
|
|
||||||
os.makedirs(args.result_dir, exist_ok=True)
|
|
||||||
fn = os.path.join(args.result_dir, 'wav.scp')
|
|
||||||
f = open(fn, 'w')
|
|
||||||
with torch.no_grad():
|
|
||||||
for batch_idx, batch in tqdm(enumerate(test_data_loader)):
|
|
||||||
utts = batch["utts"]
|
|
||||||
assert len(utts) == 1, "inference mode only support batchsize 1"
|
|
||||||
text = batch["text"]
|
|
||||||
text_token = batch["text_token"].to(device)
|
|
||||||
text_token_len = batch["text_token_len"].to(device)
|
|
||||||
tts_text = batch["tts_text"]
|
|
||||||
tts_index = batch["tts_index"]
|
|
||||||
tts_text_token = batch["tts_text_token"].to(device)
|
|
||||||
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
|
||||||
speech_token = batch["speech_token"].to(device)
|
|
||||||
speech_token_len = batch["speech_token_len"].to(device)
|
|
||||||
speech_feat = batch["speech_feat"].to(device)
|
|
||||||
speech_feat_len = batch["speech_feat_len"].to(device)
|
|
||||||
utt_embedding = batch["utt_embedding"].to(device)
|
|
||||||
spk_embedding = batch["spk_embedding"].to(device)
|
|
||||||
if args.mode == 'sft':
|
|
||||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
|
||||||
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
|
||||||
else:
|
|
||||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
|
||||||
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
|
||||||
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
|
||||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
|
||||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
|
||||||
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
|
||||||
model_output = model.inference(**model_input)
|
|
||||||
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
|
||||||
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
|
||||||
torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
|
|
||||||
f.write('{} {}\n'.format(tts_key, tts_fn))
|
|
||||||
f.flush()
|
|
||||||
f.close()
|
|
||||||
logging.info('Result wav.scp saved in {}'.format(fn))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@@ -18,6 +18,7 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import deepspeed
|
import deepspeed
|
||||||
@@ -26,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
|
|||||||
|
|
||||||
from torch.distributed.elastic.multiprocessing.errors import record
|
from torch.distributed.elastic.multiprocessing.errors import record
|
||||||
|
|
||||||
|
from cosyvoice.utils.losses import DPOLoss
|
||||||
from cosyvoice.utils.executor import Executor
|
from cosyvoice.utils.executor import Executor
|
||||||
from cosyvoice.utils.train_utils import (
|
from cosyvoice.utils.train_utils import (
|
||||||
init_distributed,
|
init_distributed,
|
||||||
@@ -42,9 +44,12 @@ def get_args():
|
|||||||
choices=['torch_ddp', 'deepspeed'],
|
choices=['torch_ddp', 'deepspeed'],
|
||||||
help='Engine for paralleled training')
|
help='Engine for paralleled training')
|
||||||
parser.add_argument('--model', required=True, help='model which will be trained')
|
parser.add_argument('--model', required=True, help='model which will be trained')
|
||||||
|
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
|
||||||
parser.add_argument('--config', required=True, help='config file')
|
parser.add_argument('--config', required=True, help='config file')
|
||||||
parser.add_argument('--train_data', required=True, help='train data file')
|
parser.add_argument('--train_data', required=True, help='train data file')
|
||||||
parser.add_argument('--cv_data', required=True, help='cv data file')
|
parser.add_argument('--cv_data', required=True, help='cv data file')
|
||||||
|
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
|
||||||
|
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
|
||||||
parser.add_argument('--checkpoint', help='checkpoint model')
|
parser.add_argument('--checkpoint', help='checkpoint model')
|
||||||
parser.add_argument('--model_dir', required=True, help='save model dir')
|
parser.add_argument('--model_dir', required=True, help='save model dir')
|
||||||
parser.add_argument('--tensorboard_dir',
|
parser.add_argument('--tensorboard_dir',
|
||||||
@@ -67,13 +72,21 @@ def get_args():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help='Use pinned memory buffers used for reading')
|
help='Use pinned memory buffers used for reading')
|
||||||
|
parser.add_argument('--use_amp',
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help='Use automatic mixed precision training')
|
||||||
|
parser.add_argument('--dpo',
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help='Use Direct Preference Optimization')
|
||||||
parser.add_argument('--deepspeed.save_states',
|
parser.add_argument('--deepspeed.save_states',
|
||||||
dest='save_states',
|
dest='save_states',
|
||||||
default='model_only',
|
default='model_only',
|
||||||
choices=['model_only', 'model+optimizer'],
|
choices=['model_only', 'model+optimizer'],
|
||||||
help='save model/optimizer states')
|
help='save model/optimizer states')
|
||||||
parser.add_argument('--timeout',
|
parser.add_argument('--timeout',
|
||||||
default=30,
|
default=60,
|
||||||
type=int,
|
type=int,
|
||||||
help='timeout (in seconds) of cosyvoice_join.')
|
help='timeout (in seconds) of cosyvoice_join.')
|
||||||
parser = deepspeed.add_config_arguments(parser)
|
parser = deepspeed.add_config_arguments(parser)
|
||||||
@@ -84,12 +97,21 @@ def get_args():
|
|||||||
@record
|
@record
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
os.environ['onnx_path'] = args.onnx_path
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
format='%(asctime)s %(levelname)s %(message)s')
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
# gan train has some special initialization logic
|
||||||
|
gan = True if args.model == 'hifigan' else False
|
||||||
|
|
||||||
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
|
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
||||||
|
if gan is True:
|
||||||
|
override_dict.pop('hift')
|
||||||
|
if args.qwen_pretrain_path is not None:
|
||||||
|
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
|
||||||
with open(args.config, 'r') as f:
|
with open(args.config, 'r') as f:
|
||||||
configs = load_hyperpyyaml(f, overrides=override_dict)
|
configs = load_hyperpyyaml(f, overrides=override_dict)
|
||||||
|
if gan is True:
|
||||||
|
configs['train_conf'] = configs['train_conf_gan']
|
||||||
configs['train_conf'].update(vars(args))
|
configs['train_conf'].update(vars(args))
|
||||||
|
|
||||||
# Init env for ddp
|
# Init env for ddp
|
||||||
@@ -97,7 +119,7 @@ def main():
|
|||||||
|
|
||||||
# Get dataset & dataloader
|
# Get dataset & dataloader
|
||||||
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
||||||
init_dataset_and_dataloader(args, configs)
|
init_dataset_and_dataloader(args, configs, gan, args.dpo)
|
||||||
|
|
||||||
# Do some sanity checks and save config to arsg.model_dir
|
# Do some sanity checks and save config to arsg.model_dir
|
||||||
configs = check_modify_and_save_config(args, configs)
|
configs = check_modify_and_save_config(args, configs)
|
||||||
@@ -106,31 +128,68 @@ def main():
|
|||||||
writer = init_summarywriter(args)
|
writer = init_summarywriter(args)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
|
if args.dpo is True:
|
||||||
|
configs[args.model].forward = configs[args.model].forward_dpo
|
||||||
model = configs[args.model]
|
model = configs[args.model]
|
||||||
|
start_step, start_epoch = 0, -1
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
|
if os.path.exists(args.checkpoint):
|
||||||
|
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
if 'step' in state_dict:
|
||||||
|
start_step = state_dict['step']
|
||||||
|
if 'epoch' in state_dict:
|
||||||
|
start_epoch = state_dict['epoch']
|
||||||
|
else:
|
||||||
|
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
||||||
|
|
||||||
# Dispatch model from cpu to gpu
|
# Dispatch model from cpu to gpu
|
||||||
model = wrap_cuda_model(args, model)
|
model = wrap_cuda_model(args, model)
|
||||||
|
|
||||||
# Get optimizer & scheduler
|
# Get optimizer & scheduler
|
||||||
model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
|
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
||||||
|
scheduler.set_step(start_step)
|
||||||
|
if scheduler_d is not None:
|
||||||
|
scheduler_d.set_step(start_step)
|
||||||
|
|
||||||
# Save init checkpoints
|
# Save init checkpoints
|
||||||
info_dict = deepcopy(configs['train_conf'])
|
info_dict = deepcopy(configs['train_conf'])
|
||||||
|
info_dict['step'] = start_step
|
||||||
|
info_dict['epoch'] = start_epoch
|
||||||
save_model(model, 'init', info_dict)
|
save_model(model, 'init', info_dict)
|
||||||
|
|
||||||
|
# DPO related
|
||||||
|
if args.dpo is True:
|
||||||
|
ref_model = deepcopy(configs[args.model])
|
||||||
|
state_dict = torch.load(args.ref_model, map_location='cpu')
|
||||||
|
ref_model.load_state_dict(state_dict, strict=False)
|
||||||
|
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
|
||||||
|
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
|
||||||
|
ref_model = wrap_cuda_model(args, ref_model)
|
||||||
|
else:
|
||||||
|
ref_model, dpo_loss = None, None
|
||||||
|
|
||||||
# Get executor
|
# Get executor
|
||||||
executor = Executor()
|
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
|
||||||
|
executor.step = start_step
|
||||||
|
|
||||||
|
# Init scaler, used for pytorch amp mixed precision training
|
||||||
|
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
||||||
|
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
||||||
|
|
||||||
# Start training loop
|
# Start training loop
|
||||||
for epoch in range(info_dict['max_epoch']):
|
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
||||||
executor.epoch = epoch
|
executor.epoch = epoch
|
||||||
train_dataset.set_epoch(epoch)
|
train_dataset.set_epoch(epoch)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
||||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
|
if gan is True:
|
||||||
|
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||||
|
writer, info_dict, scaler, group_join)
|
||||||
|
else:
|
||||||
|
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
|
||||||
dist.destroy_process_group(group_join)
|
dist.destroy_process_group(group_join)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -12,72 +12,229 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import torch
|
import time
|
||||||
|
from typing import Generator
|
||||||
|
from tqdm import tqdm
|
||||||
from hyperpyyaml import load_hyperpyyaml
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
import torch
|
||||||
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
||||||
from cosyvoice.cli.model import CosyVoiceModel
|
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||||
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
from cosyvoice.utils.class_utils import get_model_type
|
||||||
|
|
||||||
|
|
||||||
class CosyVoice:
|
class CosyVoice:
|
||||||
|
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||||
instruct = True if '-Instruct' in model_dir else False
|
|
||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
|
self.fp16 = fp16
|
||||||
if not os.path.exists(model_dir):
|
if not os.path.exists(model_dir):
|
||||||
model_dir = snapshot_download(model_dir)
|
model_dir = snapshot_download(model_dir)
|
||||||
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
|
hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
|
||||||
|
if not os.path.exists(hyper_yaml_path):
|
||||||
|
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||||
|
with open(hyper_yaml_path, 'r') as f:
|
||||||
configs = load_hyperpyyaml(f)
|
configs = load_hyperpyyaml(f)
|
||||||
|
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
||||||
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||||
configs['feat_extractor'],
|
configs['feat_extractor'],
|
||||||
'{}/campplus.onnx'.format(model_dir),
|
'{}/campplus.onnx'.format(model_dir),
|
||||||
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
||||||
'{}/spk2info.pt'.format(model_dir),
|
'{}/spk2info.pt'.format(model_dir),
|
||||||
instruct,
|
|
||||||
configs['allowed_special'])
|
configs['allowed_special'])
|
||||||
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
self.sample_rate = configs['sample_rate']
|
||||||
|
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||||
|
load_jit, load_trt, fp16 = False, False, False
|
||||||
|
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||||
|
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/hift.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
if load_jit:
|
||||||
|
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
|
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
|
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
|
if load_trt:
|
||||||
|
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
|
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||||
|
trt_concurrent,
|
||||||
|
self.fp16)
|
||||||
del configs
|
del configs
|
||||||
|
|
||||||
def list_avaliable_spks(self):
|
def list_available_spks(self):
|
||||||
spks = list(self.frontend.spk2info.keys())
|
spks = list(self.frontend.spk2info.keys())
|
||||||
return spks
|
return spks
|
||||||
|
|
||||||
def inference_sft(self, tts_text, spk_id):
|
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
|
||||||
tts_speeches = []
|
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
|
||||||
|
del model_input['text']
|
||||||
|
del model_input['text_len']
|
||||||
|
self.frontend.spk2info[zero_shot_spk_id] = model_input
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_spkinfo(self):
|
||||||
|
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
|
||||||
|
|
||||||
|
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||||
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_input = self.frontend.frontend_sft(i, spk_id)
|
model_input = self.frontend.frontend_sft(i, spk_id)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||||
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
|
||||||
tts_speeches = []
|
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
||||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
model_output = self.model.inference(**model_input)
|
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||||
|
start_time = time.time()
|
||||||
|
logging.info('synthesis text {}'.format(i))
|
||||||
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_cross_lingual(self, tts_text, prompt_speech_16k):
|
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||||
if self.frontend.instruct is True:
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||||
tts_speeches = []
|
start_time = time.time()
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
logging.info('synthesis text {}'.format(i))
|
||||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
model_output = self.model.inference(**model_input)
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
def inference_instruct(self, tts_text, spk_id, instruct_text):
|
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
||||||
if self.frontend.instruct is False:
|
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
|
||||||
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
||||||
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
tts_speeches = []
|
|
||||||
for i in self.frontend.text_normalize(tts_text, split=True):
|
|
||||||
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
||||||
model_output = self.model.inference(**model_input)
|
start_time = time.time()
|
||||||
tts_speeches.append(model_output['tts_speech'])
|
logging.info('synthesis text {}'.format(i))
|
||||||
return {'tts_speech': torch.concat(tts_speeches, dim=1)}
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
|
||||||
|
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
|
||||||
|
start_time = time.time()
|
||||||
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2(CosyVoice):
|
||||||
|
|
||||||
|
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.fp16 = fp16
|
||||||
|
if not os.path.exists(model_dir):
|
||||||
|
model_dir = snapshot_download(model_dir)
|
||||||
|
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
||||||
|
if not os.path.exists(hyper_yaml_path):
|
||||||
|
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||||
|
with open(hyper_yaml_path, 'r') as f:
|
||||||
|
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||||
|
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
||||||
|
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||||
|
configs['feat_extractor'],
|
||||||
|
'{}/campplus.onnx'.format(model_dir),
|
||||||
|
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
||||||
|
'{}/spk2info.pt'.format(model_dir),
|
||||||
|
configs['allowed_special'])
|
||||||
|
self.sample_rate = configs['sample_rate']
|
||||||
|
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
|
||||||
|
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
|
||||||
|
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
|
||||||
|
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||||
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
|
'{}/flow.pt'.format(model_dir),
|
||||||
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
if load_vllm:
|
||||||
|
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||||
|
if load_jit:
|
||||||
|
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
|
if load_trt:
|
||||||
|
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
|
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||||
|
trt_concurrent,
|
||||||
|
self.fp16)
|
||||||
|
del configs
|
||||||
|
|
||||||
|
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||||
|
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||||
|
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||||
|
start_time = time.time()
|
||||||
|
logging.info('synthesis text {}'.format(i))
|
||||||
|
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||||
|
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||||
|
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||||
|
yield model_output
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice3(CosyVoice2):
|
||||||
|
|
||||||
|
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.fp16 = fp16
|
||||||
|
if not os.path.exists(model_dir):
|
||||||
|
model_dir = snapshot_download(model_dir)
|
||||||
|
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
|
||||||
|
if not os.path.exists(hyper_yaml_path):
|
||||||
|
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||||
|
with open(hyper_yaml_path, 'r') as f:
|
||||||
|
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||||
|
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
|
||||||
|
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||||
|
configs['feat_extractor'],
|
||||||
|
'{}/campplus.onnx'.format(model_dir),
|
||||||
|
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
|
||||||
|
'{}/spk2info.pt'.format(model_dir),
|
||||||
|
configs['allowed_special'])
|
||||||
|
self.sample_rate = configs['sample_rate']
|
||||||
|
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
|
||||||
|
load_trt, fp16 = False, False
|
||||||
|
logging.warning('no cuda device, set load_trt/fp16 to False')
|
||||||
|
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||||
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
|
'{}/flow.pt'.format(model_dir),
|
||||||
|
'{}/hift.pt'.format(model_dir))
|
||||||
|
if load_vllm:
|
||||||
|
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||||
|
if load_trt:
|
||||||
|
if self.fp16 is True:
|
||||||
|
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
|
||||||
|
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||||
|
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||||
|
trt_concurrent,
|
||||||
|
self.fp16)
|
||||||
|
del configs
|
||||||
|
|
||||||
|
|
||||||
|
def AutoModel(**kwargs):
|
||||||
|
if not os.path.exists(kwargs['model_dir']):
|
||||||
|
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
|
||||||
|
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
|
||||||
|
return CosyVoice(**kwargs)
|
||||||
|
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
|
||||||
|
return CosyVoice2(**kwargs)
|
||||||
|
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
|
||||||
|
return CosyVoice3(**kwargs)
|
||||||
|
else:
|
||||||
|
raise TypeError('No valid model type found!')
|
||||||
|
|||||||
@@ -12,25 +12,19 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Generator
|
||||||
|
import json
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import whisper
|
import whisper
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
import torchaudio
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import inflect
|
import inflect
|
||||||
try:
|
from cosyvoice.utils.file_utils import logging, load_wav
|
||||||
import ttsfrd
|
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
||||||
use_ttsfrd = True
|
|
||||||
except ImportError:
|
|
||||||
print("failed to import ttsfrd, use WeTextProcessing instead")
|
|
||||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
|
||||||
from tn.english.normalizer import Normalizer as EnNormalizer
|
|
||||||
use_ttsfrd = False
|
|
||||||
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
|
|
||||||
|
|
||||||
|
|
||||||
class CosyVoiceFrontEnd:
|
class CosyVoiceFrontEnd:
|
||||||
@@ -41,7 +35,6 @@ class CosyVoiceFrontEnd:
|
|||||||
campplus_model: str,
|
campplus_model: str,
|
||||||
speech_tokenizer_model: str,
|
speech_tokenizer_model: str,
|
||||||
spk2info: str = '',
|
spk2info: str = '',
|
||||||
instruct: bool = False,
|
|
||||||
allowed_special: str = 'all'):
|
allowed_special: str = 'all'):
|
||||||
self.tokenizer = get_tokenizer()
|
self.tokenizer = get_tokenizer()
|
||||||
self.feat_extractor = feat_extractor
|
self.feat_extractor = feat_extractor
|
||||||
@@ -50,83 +43,121 @@ class CosyVoiceFrontEnd:
|
|||||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
option.intra_op_num_threads = 1
|
option.intra_op_num_threads = 1
|
||||||
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
||||||
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
|
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
||||||
|
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
||||||
|
"CPUExecutionProvider"])
|
||||||
if os.path.exists(spk2info):
|
if os.path.exists(spk2info):
|
||||||
self.spk2info = torch.load(spk2info, map_location=self.device)
|
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
|
||||||
self.instruct = instruct
|
else:
|
||||||
|
self.spk2info = {}
|
||||||
self.allowed_special = allowed_special
|
self.allowed_special = allowed_special
|
||||||
self.inflect_parser = inflect.engine()
|
self.inflect_parser = inflect.engine()
|
||||||
self.use_ttsfrd = use_ttsfrd
|
# NOTE compatible when no text frontend tool is avaliable
|
||||||
if self.use_ttsfrd:
|
try:
|
||||||
|
import ttsfrd
|
||||||
self.frd = ttsfrd.TtsFrontendEngine()
|
self.frd = ttsfrd.TtsFrontendEngine()
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
|
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
||||||
self.frd.set_lang_type('pinyin')
|
'failed to initialize ttsfrd resource'
|
||||||
self.frd.enable_pinyin_mix(True)
|
self.frd.set_lang_type('pinyinvg')
|
||||||
self.frd.set_breakmodel_index(1)
|
self.text_frontend = 'ttsfrd'
|
||||||
else:
|
logging.info('use ttsfrd frontend')
|
||||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
|
except:
|
||||||
self.en_tn_model = EnNormalizer()
|
try:
|
||||||
|
from wetext import Normalizer as ZhNormalizer
|
||||||
|
from wetext import Normalizer as EnNormalizer
|
||||||
|
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
|
||||||
|
self.en_tn_model = EnNormalizer()
|
||||||
|
self.text_frontend = 'wetext'
|
||||||
|
logging.info('use wetext frontend')
|
||||||
|
except:
|
||||||
|
self.text_frontend = ''
|
||||||
|
logging.info('no frontend is avaliable')
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_token(self, text):
|
def _extract_text_token(self, text):
|
||||||
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
if isinstance(text, Generator):
|
||||||
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
||||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
# NOTE add a dummy text_token_len for compatibility
|
||||||
return text_token, text_token_len
|
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
||||||
|
else:
|
||||||
|
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
||||||
|
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
||||||
|
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||||
|
return text_token, text_token_len
|
||||||
|
|
||||||
def _extract_speech_token(self, speech):
|
def _extract_text_token_generator(self, text_generator):
|
||||||
|
for text in text_generator:
|
||||||
|
text_token, _ = self._extract_text_token(text)
|
||||||
|
for i in range(text_token.shape[1]):
|
||||||
|
yield text_token[:, i: i + 1]
|
||||||
|
|
||||||
|
def _extract_speech_token(self, prompt_wav):
|
||||||
|
speech = load_wav(prompt_wav, 16000)
|
||||||
|
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
||||||
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
||||||
speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
|
speech_token = self.speech_tokenizer_session.run(None,
|
||||||
self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||||
|
feat.detach().cpu().numpy(),
|
||||||
|
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||||
|
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
||||||
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
||||||
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||||
return speech_token, speech_token_len
|
return speech_token, speech_token_len
|
||||||
|
|
||||||
def _extract_spk_embedding(self, speech):
|
def _extract_spk_embedding(self, prompt_wav):
|
||||||
|
speech = load_wav(prompt_wav, 16000)
|
||||||
feat = kaldi.fbank(speech,
|
feat = kaldi.fbank(speech,
|
||||||
num_mel_bins=80,
|
num_mel_bins=80,
|
||||||
dither=0,
|
dither=0,
|
||||||
sample_frequency=16000)
|
sample_frequency=16000)
|
||||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||||
embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
embedding = self.campplus_session.run(None,
|
||||||
|
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||||
embedding = torch.tensor([embedding]).to(self.device)
|
embedding = torch.tensor([embedding]).to(self.device)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def _extract_speech_feat(self, speech):
|
def _extract_speech_feat(self, prompt_wav):
|
||||||
|
speech = load_wav(prompt_wav, 24000)
|
||||||
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
||||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||||
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
||||||
return speech_feat, speech_feat_len
|
return speech_feat, speech_feat_len
|
||||||
|
|
||||||
def text_normalize(self, text, split=True):
|
def text_normalize(self, text, split=True, text_frontend=True):
|
||||||
|
if isinstance(text, Generator):
|
||||||
|
logging.info('get tts_text generator, will skip text_normalize!')
|
||||||
|
return [text]
|
||||||
|
# NOTE skip text_frontend when ssml symbol in text
|
||||||
|
if '<|' in text and '|>' in text:
|
||||||
|
text_frontend = False
|
||||||
|
if text_frontend is False or text == '':
|
||||||
|
return [text] if split is True else text
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
if contains_chinese(text):
|
if self.text_frontend == 'ttsfrd':
|
||||||
if self.use_ttsfrd:
|
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
||||||
text = self.frd.get_frd_extra_info(text, 'input')
|
text = ''.join(texts)
|
||||||
else:
|
|
||||||
text = self.zh_tn_model.normalize(text)
|
|
||||||
text = text.replace("\n", "")
|
|
||||||
text = replace_blank(text)
|
|
||||||
text = replace_corner_mark(text)
|
|
||||||
text = text.replace(".", "、")
|
|
||||||
text = text.replace(" - ", ",")
|
|
||||||
text = remove_bracket(text)
|
|
||||||
text = re.sub(r'[,,]+$', '。', text)
|
|
||||||
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
|
||||||
token_min_n=60, merge_len=20,
|
|
||||||
comma_split=False)]
|
|
||||||
else:
|
else:
|
||||||
if self.use_ttsfrd:
|
if contains_chinese(text):
|
||||||
text = self.frd.get_frd_extra_info(text, 'input')
|
if self.text_frontend == 'wetext':
|
||||||
|
text = self.zh_tn_model.normalize(text)
|
||||||
|
text = text.replace("\n", "")
|
||||||
|
text = replace_blank(text)
|
||||||
|
text = replace_corner_mark(text)
|
||||||
|
text = text.replace(".", "。")
|
||||||
|
text = text.replace(" - ", ",")
|
||||||
|
text = remove_bracket(text)
|
||||||
|
text = re.sub(r'[,,、]+$', '。', text)
|
||||||
|
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
||||||
|
token_min_n=60, merge_len=20, comma_split=False))
|
||||||
else:
|
else:
|
||||||
text = self.en_tn_model.normalize(text)
|
if self.text_frontend == 'wetext':
|
||||||
text = spell_out_number(text, self.inflect_parser)
|
text = self.en_tn_model.normalize(text)
|
||||||
texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
text = spell_out_number(text, self.inflect_parser)
|
||||||
token_min_n=60, merge_len=20,
|
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
||||||
comma_split=False)]
|
token_min_n=60, merge_len=20, comma_split=False))
|
||||||
if split is False:
|
texts = [i for i in texts if not is_only_punctuation(i)]
|
||||||
return text
|
return texts if split is True else text
|
||||||
return texts
|
|
||||||
|
|
||||||
def frontend_sft(self, tts_text, spk_id):
|
def frontend_sft(self, tts_text, spk_id):
|
||||||
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
||||||
@@ -134,23 +165,31 @@ class CosyVoiceFrontEnd:
|
|||||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||||
return model_input
|
return model_input
|
||||||
|
|
||||||
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
|
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||||
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
||||||
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
if zero_shot_spk_id == '':
|
||||||
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
|
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
||||||
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
|
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
|
||||||
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
|
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
|
||||||
embedding = self._extract_spk_embedding(prompt_speech_16k)
|
if resample_rate == 24000:
|
||||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
# cosyvoice2, force speech_feat % speech_token = 2
|
||||||
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
||||||
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
||||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
||||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
embedding = self._extract_spk_embedding(prompt_wav)
|
||||||
'llm_embedding': embedding, 'flow_embedding': embedding}
|
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
||||||
|
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
||||||
|
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
||||||
|
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
||||||
|
'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||||
|
else:
|
||||||
|
model_input = {**self.spk2info[zero_shot_spk_id]}
|
||||||
|
model_input['text'] = tts_text_token
|
||||||
|
model_input['text_len'] = tts_text_token_len
|
||||||
return model_input
|
return model_input
|
||||||
|
|
||||||
def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
|
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
|
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
|
||||||
# in cross lingual mode, we remove prompt in llm
|
# in cross lingual mode, we remove prompt in llm
|
||||||
del model_input['prompt_text']
|
del model_input['prompt_text']
|
||||||
del model_input['prompt_text_len']
|
del model_input['prompt_text_len']
|
||||||
@@ -162,7 +201,24 @@ class CosyVoiceFrontEnd:
|
|||||||
model_input = self.frontend_sft(tts_text, spk_id)
|
model_input = self.frontend_sft(tts_text, spk_id)
|
||||||
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
||||||
del model_input['llm_embedding']
|
del model_input['llm_embedding']
|
||||||
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
|
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
|
||||||
model_input['prompt_text'] = instruct_text_token
|
model_input['prompt_text'] = instruct_text_token
|
||||||
model_input['prompt_text_len'] = instruct_text_token_len
|
model_input['prompt_text_len'] = instruct_text_token_len
|
||||||
return model_input
|
return model_input
|
||||||
|
|
||||||
|
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||||
|
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
|
||||||
|
del model_input['llm_prompt_speech_token']
|
||||||
|
del model_input['llm_prompt_speech_token_len']
|
||||||
|
return model_input
|
||||||
|
|
||||||
|
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
|
||||||
|
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
|
||||||
|
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
|
||||||
|
embedding = self._extract_spk_embedding(prompt_wav)
|
||||||
|
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
||||||
|
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
||||||
|
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
||||||
|
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
||||||
|
'flow_embedding': embedding}
|
||||||
|
return model_input
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -11,50 +12,439 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
from typing import Generator
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from contextlib import nullcontext
|
||||||
|
import uuid
|
||||||
|
from cosyvoice.utils.common import fade_in_out
|
||||||
|
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||||
|
from cosyvoice.utils.common import TrtContextWrapper
|
||||||
|
|
||||||
|
|
||||||
class CosyVoiceModel:
|
class CosyVoiceModel:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
llm: torch.nn.Module,
|
llm: torch.nn.Module,
|
||||||
flow: torch.nn.Module,
|
flow: torch.nn.Module,
|
||||||
hift: torch.nn.Module):
|
hift: torch.nn.Module,
|
||||||
|
fp16: bool = False):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.hift = hift
|
self.hift = hift
|
||||||
|
self.fp16 = fp16
|
||||||
|
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
||||||
|
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
||||||
|
self.token_overlap_len = 20
|
||||||
|
# mel fade in out
|
||||||
|
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
||||||
|
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
||||||
|
# hift cache
|
||||||
|
self.mel_cache_len = 20
|
||||||
|
self.source_cache_len = int(self.mel_cache_len * 256)
|
||||||
|
# speech fade in out
|
||||||
|
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||||
|
# rtf and decoding related
|
||||||
|
self.stream_scale_factor = 1
|
||||||
|
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||||
|
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
# dict used to store session related variable
|
||||||
|
self.tts_speech_token_dict = {}
|
||||||
|
self.llm_end_dict = {}
|
||||||
|
self.mel_overlap_dict = {}
|
||||||
|
self.flow_cache_dict = {}
|
||||||
|
self.hift_cache_dict = {}
|
||||||
|
self.silent_tokens = []
|
||||||
|
|
||||||
def load(self, llm_model, flow_model, hift_model):
|
def load(self, llm_model, flow_model, hift_model):
|
||||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
|
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
|
||||||
self.llm.to(self.device).eval()
|
self.llm.to(self.device).eval()
|
||||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
|
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
|
||||||
self.flow.to(self.device).eval()
|
self.flow.to(self.device).eval()
|
||||||
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
# in case hift_model is a hifigan model
|
||||||
|
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
|
||||||
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||||
self.hift.to(self.device).eval()
|
self.hift.to(self.device).eval()
|
||||||
|
|
||||||
def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
||||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
|
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
||||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
self.llm.text_encoder = llm_text_encoder
|
||||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
||||||
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
|
self.llm.llm = llm_llm
|
||||||
tts_speech_token = self.llm.inference(text=text.to(self.device),
|
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||||
text_len=text_len.to(self.device),
|
self.flow.encoder = flow_encoder
|
||||||
prompt_text=prompt_text.to(self.device),
|
|
||||||
prompt_text_len=prompt_text_len.to(self.device),
|
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||||
prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
|
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||||
embedding=llm_embedding.to(self.device),
|
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||||
beam_size=1,
|
del self.flow.decoder.estimator
|
||||||
sampling=25,
|
import tensorrt as trt
|
||||||
max_token_text_ratio=30,
|
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||||
min_token_text_ratio=3)
|
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
tts_mel = self.flow.inference(token=tts_speech_token,
|
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||||
token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
|
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||||
prompt_token=flow_prompt_speech_token.to(self.device),
|
|
||||||
prompt_token_len=flow_prompt_speech_token_len.to(self.device),
|
def get_trt_kwargs(self):
|
||||||
prompt_feat=prompt_speech_feat.to(self.device),
|
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||||
prompt_feat_len=prompt_speech_feat_len.to(self.device),
|
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||||
embedding=flow_embedding.to(self.device))
|
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||||
tts_speech = self.hift.inference(mel=tts_mel).cpu()
|
input_names = ["x", "mask", "mu", "cond"]
|
||||||
torch.cuda.empty_cache()
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
return {'tts_speech': tts_speech}
|
|
||||||
|
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||||
|
cur_silent_token_num, max_silent_token_num = 0, 5
|
||||||
|
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
|
||||||
|
if isinstance(text, Generator):
|
||||||
|
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
|
||||||
|
token_generator = self.llm.inference_bistream(text=text,
|
||||||
|
prompt_text=prompt_text.to(self.device),
|
||||||
|
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||||
|
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=llm_embedding.to(self.device))
|
||||||
|
else:
|
||||||
|
token_generator = self.llm.inference(text=text.to(self.device),
|
||||||
|
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_text=prompt_text.to(self.device),
|
||||||
|
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||||
|
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=llm_embedding.to(self.device),
|
||||||
|
uuid=uuid)
|
||||||
|
for i in token_generator:
|
||||||
|
if i in self.silent_tokens:
|
||||||
|
cur_silent_token_num += 1
|
||||||
|
if cur_silent_token_num > max_silent_token_num:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
cur_silent_token_num = 0
|
||||||
|
self.tts_speech_token_dict[uuid].append(i)
|
||||||
|
self.llm_end_dict[uuid] = True
|
||||||
|
|
||||||
|
def vc_job(self, source_speech_token, uuid):
|
||||||
|
self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
|
||||||
|
self.llm_end_dict[uuid] = True
|
||||||
|
|
||||||
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||||
|
with torch.cuda.amp.autocast(self.fp16):
|
||||||
|
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||||
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=embedding.to(self.device),
|
||||||
|
flow_cache=self.flow_cache_dict[uuid])
|
||||||
|
|
||||||
|
# mel overlap fade in out
|
||||||
|
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
||||||
|
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
||||||
|
# append hift cache
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||||
|
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||||
|
else:
|
||||||
|
hift_cache_source = torch.zeros(1, 1, 0)
|
||||||
|
# keep overlap mel and hift cache
|
||||||
|
if finalize is False:
|
||||||
|
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
||||||
|
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||||
|
'source': tts_source[:, :, -self.source_cache_len:],
|
||||||
|
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||||
|
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||||
|
else:
|
||||||
|
if speed != 1.0:
|
||||||
|
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||||
|
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
return tts_speech
|
||||||
|
|
||||||
|
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
|
||||||
|
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||||
|
# this_uuid is used to track variables related to this inference thread
|
||||||
|
this_uuid = str(uuid.uuid1())
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||||
|
self.hift_cache_dict[this_uuid] = None
|
||||||
|
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||||
|
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||||
|
if source_speech_token.shape[1] == 0:
|
||||||
|
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||||
|
else:
|
||||||
|
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||||
|
p.start()
|
||||||
|
if stream is True:
|
||||||
|
token_hop_len = self.token_min_hop_len
|
||||||
|
while True:
|
||||||
|
time.sleep(0.1)
|
||||||
|
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||||
|
.unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=False)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||||
|
# increase token_hop_len for better speech quality
|
||||||
|
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||||
|
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||||
|
break
|
||||||
|
p.join()
|
||||||
|
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
else:
|
||||||
|
# deal with all tokens
|
||||||
|
p.join()
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True,
|
||||||
|
speed=speed)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict.pop(this_uuid)
|
||||||
|
self.llm_end_dict.pop(this_uuid)
|
||||||
|
self.mel_overlap_dict.pop(this_uuid)
|
||||||
|
self.hift_cache_dict.pop(this_uuid)
|
||||||
|
self.flow_cache_dict.pop(this_uuid)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2Model(CosyVoiceModel):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
llm: torch.nn.Module,
|
||||||
|
flow: torch.nn.Module,
|
||||||
|
hift: torch.nn.Module,
|
||||||
|
fp16: bool = False):
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.llm = llm
|
||||||
|
self.flow = flow
|
||||||
|
self.hift = hift
|
||||||
|
self.fp16 = fp16
|
||||||
|
# NOTE must matching training static_chunk_size
|
||||||
|
self.token_hop_len = 25
|
||||||
|
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||||
|
self.token_max_hop_len = 4 * self.token_hop_len
|
||||||
|
self.stream_scale_factor = 2
|
||||||
|
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||||
|
# hift cache
|
||||||
|
self.mel_cache_len = 8
|
||||||
|
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||||
|
# speech fade in out
|
||||||
|
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||||
|
# rtf and decoding related
|
||||||
|
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
# dict used to store session related variable
|
||||||
|
self.tts_speech_token_dict = {}
|
||||||
|
self.llm_end_dict = {}
|
||||||
|
self.hift_cache_dict = {}
|
||||||
|
self.silent_tokens = []
|
||||||
|
|
||||||
|
def load_jit(self, flow_encoder_model):
|
||||||
|
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||||
|
self.flow.encoder = flow_encoder
|
||||||
|
|
||||||
|
def load_vllm(self, model_dir):
|
||||||
|
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
|
||||||
|
from vllm import EngineArgs, LLMEngine
|
||||||
|
engine_args = EngineArgs(model=model_dir,
|
||||||
|
skip_tokenizer_init=True,
|
||||||
|
enable_prompt_embeds=True,
|
||||||
|
gpu_memory_utilization=0.2)
|
||||||
|
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
|
||||||
|
self.llm.lock = threading.Lock()
|
||||||
|
del self.llm.llm.model.model.layers
|
||||||
|
|
||||||
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||||
|
with torch.cuda.amp.autocast(self.fp16):
|
||||||
|
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||||
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=embedding.to(self.device),
|
||||||
|
streaming=stream,
|
||||||
|
finalize=finalize)
|
||||||
|
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||||
|
# append hift cache
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||||
|
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||||
|
else:
|
||||||
|
hift_cache_source = torch.zeros(1, 1, 0)
|
||||||
|
# keep overlap mel and hift cache
|
||||||
|
if finalize is False:
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||||
|
'source': tts_source[:, :, -self.source_cache_len:],
|
||||||
|
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||||
|
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||||
|
else:
|
||||||
|
if speed != 1.0:
|
||||||
|
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||||
|
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||||
|
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||||
|
return tts_speech
|
||||||
|
|
||||||
|
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
|
||||||
|
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||||
|
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||||
|
# this_uuid is used to track variables related to this inference thread
|
||||||
|
this_uuid = str(uuid.uuid1())
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||||
|
self.hift_cache_dict[this_uuid] = None
|
||||||
|
if source_speech_token.shape[1] == 0:
|
||||||
|
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||||
|
else:
|
||||||
|
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||||
|
p.start()
|
||||||
|
if stream is True:
|
||||||
|
token_offset = 0
|
||||||
|
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
|
||||||
|
while True:
|
||||||
|
time.sleep(0.1)
|
||||||
|
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
|
||||||
|
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
token_offset=token_offset,
|
||||||
|
uuid=this_uuid,
|
||||||
|
stream=stream,
|
||||||
|
finalize=False)
|
||||||
|
token_offset += this_token_hop_len
|
||||||
|
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
|
||||||
|
break
|
||||||
|
p.join()
|
||||||
|
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
token_offset=token_offset,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
else:
|
||||||
|
# deal with all tokens
|
||||||
|
p.join()
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||||
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
|
prompt_token=flow_prompt_speech_token,
|
||||||
|
prompt_feat=prompt_speech_feat,
|
||||||
|
embedding=flow_embedding,
|
||||||
|
token_offset=0,
|
||||||
|
uuid=this_uuid,
|
||||||
|
finalize=True,
|
||||||
|
speed=speed)
|
||||||
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
|
with self.lock:
|
||||||
|
self.tts_speech_token_dict.pop(this_uuid)
|
||||||
|
self.llm_end_dict.pop(this_uuid)
|
||||||
|
self.hift_cache_dict.pop(this_uuid)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice3Model(CosyVoice2Model):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
llm: torch.nn.Module,
|
||||||
|
flow: torch.nn.Module,
|
||||||
|
hift: torch.nn.Module,
|
||||||
|
fp16: bool = False):
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.llm = llm
|
||||||
|
self.flow = flow
|
||||||
|
self.hift = hift
|
||||||
|
self.fp16 = fp16
|
||||||
|
# NOTE must matching training static_chunk_size
|
||||||
|
self.token_hop_len = 25
|
||||||
|
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||||
|
self.token_max_hop_len = 4 * self.token_hop_len
|
||||||
|
self.stream_scale_factor = 2
|
||||||
|
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||||
|
# rtf and decoding related
|
||||||
|
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
# dict used to store session related variable
|
||||||
|
self.tts_speech_token_dict = {}
|
||||||
|
self.llm_end_dict = {}
|
||||||
|
self.hift_cache_dict = {}
|
||||||
|
# FSQ silent and breath token
|
||||||
|
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
|
||||||
|
|
||||||
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||||
|
with torch.cuda.amp.autocast(self.fp16):
|
||||||
|
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||||
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_token=prompt_token.to(self.device),
|
||||||
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
embedding=embedding.to(self.device),
|
||||||
|
streaming=stream,
|
||||||
|
finalize=finalize)
|
||||||
|
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||||
|
# append mel cache
|
||||||
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
|
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
|
||||||
|
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||||
|
self.hift_cache_dict[uuid]['mel'] = tts_mel
|
||||||
|
else:
|
||||||
|
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
|
||||||
|
if speed != 1.0:
|
||||||
|
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
|
||||||
|
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||||
|
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
|
||||||
|
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
|
||||||
|
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
|
||||||
|
return tts_speech
|
||||||
|
|||||||
@@ -14,14 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import json
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
from cosyvoice.utils.file_utils import read_lists
|
||||||
|
|
||||||
|
|
||||||
class Processor(IterableDataset):
|
class Processor(IterableDataset):
|
||||||
@@ -126,10 +125,10 @@ class DataList(IterableDataset):
|
|||||||
def Dataset(data_list_file,
|
def Dataset(data_list_file,
|
||||||
data_pipeline,
|
data_pipeline,
|
||||||
mode='train',
|
mode='train',
|
||||||
|
gan=False,
|
||||||
|
dpo=False,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
partition=True,
|
partition=True):
|
||||||
tts_file='',
|
|
||||||
prompt_utt2data=''):
|
|
||||||
""" Construct dataset from arguments
|
""" Construct dataset from arguments
|
||||||
|
|
||||||
We have two shuffle stage in the Dataset. The first is global
|
We have two shuffle stage in the Dataset. The first is global
|
||||||
@@ -141,20 +140,16 @@ def Dataset(data_list_file,
|
|||||||
tokenizer (BaseTokenizer): tokenizer to tokenize
|
tokenizer (BaseTokenizer): tokenizer to tokenize
|
||||||
partition(bool): whether to do data partition in terms of rank
|
partition(bool): whether to do data partition in terms of rank
|
||||||
"""
|
"""
|
||||||
assert mode in ['train', 'inference']
|
|
||||||
lists = read_lists(data_list_file)
|
lists = read_lists(data_list_file)
|
||||||
if mode == 'inference':
|
|
||||||
with open(tts_file) as f:
|
|
||||||
tts_data = json.load(f)
|
|
||||||
utt2lists = read_json_lists(prompt_utt2data)
|
|
||||||
# filter unnecessary file in inference mode
|
|
||||||
lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
|
|
||||||
dataset = DataList(lists,
|
dataset = DataList(lists,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
partition=partition)
|
partition=partition)
|
||||||
if mode == 'inference':
|
# map partial arg to padding func
|
||||||
# map partial arg tts_data in inference mode
|
for i in range(1, len(data_pipeline)):
|
||||||
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
|
||||||
|
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
|
||||||
|
if data_pipeline[i].func.__name__ == 'padding':
|
||||||
|
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
|
||||||
for func in data_pipeline:
|
for func in data_pipeline:
|
||||||
dataset = Processor(dataset, func, mode=mode)
|
dataset = Processor(dataset, func, mode=mode)
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -16,17 +16,19 @@ import random
|
|||||||
|
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import numpy as np
|
||||||
|
import whisper
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import pyworld as pw
|
||||||
|
from cosyvoice.utils.onnx import embedding_extractor, online_feature
|
||||||
|
|
||||||
torchaudio.set_audio_backend('soundfile')
|
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
||||||
|
|
||||||
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
|
||||||
|
|
||||||
|
|
||||||
def parquet_opener(data, mode='train', tts_data={}):
|
def parquet_opener(data, mode='train'):
|
||||||
""" Give url or local file, return file descriptor
|
""" Give url or local file, return file descriptor
|
||||||
Inplace operation.
|
Inplace operation.
|
||||||
|
|
||||||
@@ -40,20 +42,16 @@ def parquet_opener(data, mode='train', tts_data={}):
|
|||||||
assert 'src' in sample
|
assert 'src' in sample
|
||||||
url = sample['src']
|
url = sample['src']
|
||||||
try:
|
try:
|
||||||
df = pq.read_table(url).to_pandas()
|
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
||||||
for i in range(len(df)):
|
df = df.to_pandas()
|
||||||
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
for i in range(len(df)):
|
||||||
continue
|
sample.update(dict(df.loc[i]))
|
||||||
sample.update(dict(df.loc[i]))
|
|
||||||
if mode == 'train':
|
|
||||||
# NOTE do not return sample directly, must initialize a new dict
|
# NOTE do not return sample directly, must initialize a new dict
|
||||||
yield {**sample}
|
yield {**sample}
|
||||||
else:
|
|
||||||
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
|
||||||
yield {**sample, 'tts_index': index, 'tts_text': text}
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
||||||
|
|
||||||
|
|
||||||
def filter(data,
|
def filter(data,
|
||||||
max_length=10240,
|
max_length=10240,
|
||||||
min_length=10,
|
min_length=10,
|
||||||
@@ -84,6 +82,7 @@ def filter(data,
|
|||||||
"""
|
"""
|
||||||
for sample in data:
|
for sample in data:
|
||||||
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
||||||
|
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
||||||
del sample['audio_data']
|
del sample['audio_data']
|
||||||
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
||||||
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
||||||
@@ -95,7 +94,9 @@ def filter(data,
|
|||||||
continue
|
continue
|
||||||
if len(sample['text_token']) > token_max_length:
|
if len(sample['text_token']) > token_max_length:
|
||||||
continue
|
continue
|
||||||
if len(sample['speech_token']) == 0:
|
if online_feature is False and len(sample['speech_token']) == 0:
|
||||||
|
continue
|
||||||
|
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||||
continue
|
continue
|
||||||
if num_frames != 0:
|
if num_frames != 0:
|
||||||
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
||||||
@@ -133,8 +134,30 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
|||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
def truncate(data, truncate_length=24576, mode='train'):
|
||||||
|
""" Truncate data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Iterable[{key, wav, label, sample_rate}]
|
||||||
|
truncate_length: truncate length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[{key, wav, label, sample_rate}]
|
||||||
|
"""
|
||||||
|
for sample in data:
|
||||||
|
waveform = sample['speech']
|
||||||
|
if waveform.shape[1] > truncate_length:
|
||||||
|
start = random.randint(0, waveform.shape[1] - truncate_length)
|
||||||
|
waveform = waveform[:, start: start + truncate_length]
|
||||||
|
else:
|
||||||
|
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
||||||
|
sample['speech'] = waveform
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
feat_extractor,
|
feat_extractor,
|
||||||
|
num_frames=-1,
|
||||||
mode='train'):
|
mode='train'):
|
||||||
""" Extract fbank
|
""" Extract fbank
|
||||||
|
|
||||||
@@ -144,15 +167,58 @@ def compute_fbank(data,
|
|||||||
Returns:
|
Returns:
|
||||||
Iterable[{key, feat, label}]
|
Iterable[{key, feat, label}]
|
||||||
"""
|
"""
|
||||||
|
for sample in data:
|
||||||
|
assert 'sample_rate' in sample
|
||||||
|
assert 'speech' in sample
|
||||||
|
assert 'utt' in sample
|
||||||
|
assert 'text_token' in sample
|
||||||
|
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
|
||||||
|
if num_frames != -1:
|
||||||
|
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
|
||||||
|
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
|
||||||
|
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
|
||||||
|
""" Extract whisper fbank
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Iterable[{key, wav, label, sample_rate}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[{key, feat, label}]
|
||||||
|
"""
|
||||||
|
for sample in data:
|
||||||
|
if num_frames != -1:
|
||||||
|
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
|
||||||
|
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||||
|
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
def compute_f0(data, sample_rate, hop_size, mode='train'):
|
||||||
|
""" Extract f0
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Iterable[{key, wav, label, sample_rate}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[{key, feat, label}]
|
||||||
|
"""
|
||||||
|
frame_period = hop_size * 1000 / sample_rate
|
||||||
for sample in data:
|
for sample in data:
|
||||||
assert 'sample_rate' in sample
|
assert 'sample_rate' in sample
|
||||||
assert 'speech' in sample
|
assert 'speech' in sample
|
||||||
assert 'utt' in sample
|
assert 'utt' in sample
|
||||||
assert 'text_token' in sample
|
assert 'text_token' in sample
|
||||||
waveform = sample['speech']
|
waveform = sample['speech']
|
||||||
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
|
||||||
sample['speech_feat'] = mat
|
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
|
||||||
del sample['speech']
|
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
|
||||||
|
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
|
||||||
|
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
|
||||||
|
sample['pitch_feat'] = f0
|
||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
@@ -166,8 +232,13 @@ def parse_embedding(data, normalize, mode='train'):
|
|||||||
Iterable[{key, feat, label}]
|
Iterable[{key, feat, label}]
|
||||||
"""
|
"""
|
||||||
for sample in data:
|
for sample in data:
|
||||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
|
||||||
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||||
|
embedding = embedding_extractor.inference(sample['speech_16k'])
|
||||||
|
sample['spk_embedding'] = sample['utt_embedding'] = embedding
|
||||||
|
else:
|
||||||
|
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||||
|
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||||
if normalize:
|
if normalize:
|
||||||
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
||||||
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
||||||
@@ -188,8 +259,8 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
|||||||
for sample in data:
|
for sample in data:
|
||||||
assert 'text' in sample
|
assert 'text' in sample
|
||||||
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
||||||
if mode == 'inference':
|
if 'instruct' in sample:
|
||||||
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
|
||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
@@ -204,13 +275,14 @@ def shuffle(data, shuffle_size=10000, mode='train'):
|
|||||||
Iterable[{key, feat, label}]
|
Iterable[{key, feat, label}]
|
||||||
"""
|
"""
|
||||||
buf = []
|
buf = []
|
||||||
|
yield_size = int(shuffle_size / 2)
|
||||||
for sample in data:
|
for sample in data:
|
||||||
buf.append(sample)
|
buf.append(sample)
|
||||||
if len(buf) >= shuffle_size:
|
if len(buf) >= shuffle_size:
|
||||||
random.shuffle(buf)
|
random.shuffle(buf)
|
||||||
for x in buf:
|
for x in buf[:yield_size]:
|
||||||
yield x
|
yield x
|
||||||
buf = []
|
buf = buf[yield_size:]
|
||||||
# The sample left over
|
# The sample left over
|
||||||
random.shuffle(buf)
|
random.shuffle(buf)
|
||||||
for x in buf:
|
for x in buf:
|
||||||
@@ -297,18 +369,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
|||||||
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
||||||
""" Wrapper for static/dynamic batch
|
""" Wrapper for static/dynamic batch
|
||||||
"""
|
"""
|
||||||
if mode == 'inference':
|
if batch_type == 'static':
|
||||||
return static_batch(data, 1)
|
return static_batch(data, batch_size)
|
||||||
|
elif batch_type == 'dynamic':
|
||||||
|
return dynamic_batch(data, max_frames_in_batch)
|
||||||
else:
|
else:
|
||||||
if batch_type == 'static':
|
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
||||||
return static_batch(data, batch_size)
|
|
||||||
elif batch_type == 'dynamic':
|
|
||||||
return dynamic_batch(data, max_frames_in_batch)
|
|
||||||
else:
|
|
||||||
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
|
||||||
|
|
||||||
|
|
||||||
def padding(data, use_spk_embedding, mode='train'):
|
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||||
""" Padding the data into training data
|
""" Padding the data into training data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -319,49 +388,42 @@ def padding(data, use_spk_embedding, mode='train'):
|
|||||||
"""
|
"""
|
||||||
for sample in data:
|
for sample in data:
|
||||||
assert isinstance(sample, list)
|
assert isinstance(sample, list)
|
||||||
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
|
||||||
dtype=torch.int32)
|
batch = {}
|
||||||
order = torch.argsort(speech_feat_len, descending=True)
|
batch['utts'] = [sample[i]['utt'] for i in order]
|
||||||
|
batch['text'] = [sample[i]['text'] for i in order]
|
||||||
utts = [sample[i]['utt'] for i in order]
|
|
||||||
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
|
||||||
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
|
||||||
speech_token = pad_sequence(speech_token,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=0)
|
|
||||||
speech_feat = [sample[i]['speech_feat'] for i in order]
|
|
||||||
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
|
||||||
speech_feat = pad_sequence(speech_feat,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=0)
|
|
||||||
text = [sample[i]['text'] for i in order]
|
|
||||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||||
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||||
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
|
||||||
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
speech_feat = [sample[i]['speech_feat'] for i in order]
|
||||||
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
||||||
batch = {
|
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
|
||||||
"utts": utts,
|
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
||||||
"speech_token": speech_token,
|
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||||
"speech_token_len": speech_token_len,
|
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
|
||||||
"speech_feat": speech_feat,
|
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
|
||||||
"speech_feat_len": speech_feat_len,
|
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
|
||||||
"text": text,
|
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
|
||||||
"text_token": text_token,
|
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
|
||||||
"text_token_len": text_token_len,
|
whisper_feat = [sample[i]['whisper_feat'] for i in order]
|
||||||
"utt_embedding": utt_embedding,
|
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
|
||||||
"spk_embedding": spk_embedding,
|
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
|
||||||
}
|
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
|
||||||
if mode == 'inference':
|
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
||||||
tts_text = [sample[i]['tts_text'] for i in order]
|
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
||||||
tts_index = [sample[i]['tts_index'] for i in order]
|
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
|
||||||
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
if gan is True:
|
||||||
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
# in gan train, we need speech/pitch_feat
|
||||||
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
||||||
batch.update({'tts_text': tts_text,
|
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
||||||
'tts_index': tts_index,
|
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
|
||||||
'tts_text_token': tts_text_token,
|
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||||
'tts_text_token_len': tts_text_token_len})
|
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||||
|
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
|
||||||
|
if dpo is True:
|
||||||
|
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
|
||||||
|
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||||
|
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
|
||||||
if use_spk_embedding is True:
|
if use_spk_embedding is True:
|
||||||
batch["embedding"] = batch["spk_embedding"]
|
batch["embedding"] = batch["spk_embedding"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
176
cosyvoice/flow/DiT/dit.py
Normal file
176
cosyvoice/flow/DiT/dit.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
|
||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import repeat
|
||||||
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||||
|
from cosyvoice.flow.DiT.modules import (
|
||||||
|
TimestepEmbedding,
|
||||||
|
ConvNeXtV2Block,
|
||||||
|
CausalConvPositionEmbedding,
|
||||||
|
DiTBlock,
|
||||||
|
AdaLayerNormZero_Final,
|
||||||
|
precompute_freqs_cis,
|
||||||
|
get_pos_embed_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Text embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding(nn.Module):
|
||||||
|
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||||
|
super().__init__()
|
||||||
|
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||||
|
|
||||||
|
if conv_layers > 0:
|
||||||
|
self.extra_modeling = True
|
||||||
|
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||||
|
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||||
|
self.text_blocks = nn.Sequential(
|
||||||
|
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.extra_modeling = False
|
||||||
|
|
||||||
|
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||||
|
batch, text_len = text.shape[0], text.shape[1]
|
||||||
|
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||||
|
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||||
|
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||||
|
|
||||||
|
if drop_text: # cfg for text
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
|
||||||
|
text = self.text_embed(text) # b n -> b n d
|
||||||
|
|
||||||
|
# possible extra modeling
|
||||||
|
if self.extra_modeling:
|
||||||
|
# sinus pos emb
|
||||||
|
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||||
|
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||||
|
text_pos_embed = self.freqs_cis[pos_idx]
|
||||||
|
text = text + text_pos_embed
|
||||||
|
|
||||||
|
# convnextv2 blocks
|
||||||
|
text = self.text_blocks(text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# noised input audio and context mixing embedding
|
||||||
|
|
||||||
|
|
||||||
|
class InputEmbedding(nn.Module):
|
||||||
|
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
|
||||||
|
super().__init__()
|
||||||
|
spk_dim = 0 if spk_dim is None else spk_dim
|
||||||
|
self.spk_dim = spk_dim
|
||||||
|
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
|
||||||
|
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"],
|
||||||
|
cond: float["b n d"],
|
||||||
|
text_embed: float["b n d"],
|
||||||
|
spks: float["b d"],
|
||||||
|
):
|
||||||
|
to_cat = [x, cond, text_embed]
|
||||||
|
if self.spk_dim > 0:
|
||||||
|
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
|
||||||
|
to_cat.append(spks)
|
||||||
|
|
||||||
|
x = self.proj(torch.cat(to_cat, dim=-1))
|
||||||
|
x = self.conv_pos_embed(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Transformer backbone using DiT blocks
|
||||||
|
|
||||||
|
|
||||||
|
class DiT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=8,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.1,
|
||||||
|
ff_mult=4,
|
||||||
|
mel_dim=80,
|
||||||
|
mu_dim=None,
|
||||||
|
long_skip_connection=False,
|
||||||
|
spk_dim=None,
|
||||||
|
out_channels=None,
|
||||||
|
static_chunk_size=50,
|
||||||
|
num_decoding_left_chunks=2
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.time_embed = TimestepEmbedding(dim)
|
||||||
|
if mu_dim is None:
|
||||||
|
mu_dim = mel_dim
|
||||||
|
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
|
||||||
|
|
||||||
|
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||||
|
)
|
||||||
|
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||||
|
|
||||||
|
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||||
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.static_chunk_size = static_chunk_size
|
||||||
|
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||||
|
|
||||||
|
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
mu = mu.transpose(1, 2)
|
||||||
|
cond = cond.transpose(1, 2)
|
||||||
|
spks = spks.unsqueeze(dim=1)
|
||||||
|
batch, seq_len = x.shape[0], x.shape[1]
|
||||||
|
if t.ndim == 0:
|
||||||
|
t = t.repeat(batch)
|
||||||
|
|
||||||
|
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||||
|
t = self.time_embed(t)
|
||||||
|
x = self.input_embed(x, cond, mu, spks.squeeze(1))
|
||||||
|
|
||||||
|
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if streaming is True:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
|
||||||
|
else:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, t, mask=attn_mask.bool(), rope=rope)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||||
|
|
||||||
|
x = self.norm_out(x, t)
|
||||||
|
output = self.proj_out(x).transpose(1, 2)
|
||||||
|
return output
|
||||||
616
cosyvoice/flow/DiT/modules.py
Normal file
616
cosyvoice/flow/DiT/modules.py
Normal file
@@ -0,0 +1,616 @@
|
|||||||
|
|
||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Optional
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
# raw wav to mel spec
|
||||||
|
class MelSpec(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filter_length=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24_000,
|
||||||
|
normalize=False,
|
||||||
|
power=1,
|
||||||
|
norm=None,
|
||||||
|
center=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.n_mel_channels = n_mel_channels
|
||||||
|
|
||||||
|
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=target_sample_rate,
|
||||||
|
n_fft=filter_length,
|
||||||
|
win_length=win_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
n_mels=n_mel_channels,
|
||||||
|
power=power,
|
||||||
|
center=center,
|
||||||
|
normalized=normalize,
|
||||||
|
norm=norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, inp):
|
||||||
|
if len(inp.shape) == 3:
|
||||||
|
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
|
||||||
|
|
||||||
|
assert len(inp.shape) == 2
|
||||||
|
|
||||||
|
if self.dummy.device != inp.device:
|
||||||
|
self.to(inp.device)
|
||||||
|
|
||||||
|
mel = self.mel_stft(inp)
|
||||||
|
mel = mel.clamp(min=1e-5).log()
|
||||||
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
# sinusoidal position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class SinusPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x, scale=1000):
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||||
|
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
# convolutional position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class ConvPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, kernel_size=31, groups=16):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 != 0
|
||||||
|
self.conv1d = nn.Sequential(
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[..., None]
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = self.conv1d(x)
|
||||||
|
out = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
out = out.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConvPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, kernel_size=31, groups=16):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 != 0
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.conv1 = nn.Sequential(
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Sequential(
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[..., None]
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||||
|
x = self.conv2(x)
|
||||||
|
out = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
out = out.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# rotary positional embedding related
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||||
|
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
|
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||||
|
freqs_cos = torch.cos(freqs) # real part
|
||||||
|
freqs_sin = torch.sin(freqs) # imaginary part
|
||||||
|
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||||
|
# length = length if isinstance(length, int) else length.max()
|
||||||
|
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
||||||
|
pos = (
|
||||||
|
start.unsqueeze(1)
|
||||||
|
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
||||||
|
)
|
||||||
|
# avoid extra long error.
|
||||||
|
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
# Global Response Normalization layer (Instance Normalization ?)
|
||||||
|
|
||||||
|
|
||||||
|
class GRN(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||||
|
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return self.gamma * (x * Nx) + self.beta + x
|
||||||
|
|
||||||
|
|
||||||
|
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||||
|
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNeXtV2Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
intermediate_dim: int,
|
||||||
|
dilation: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
padding = (dilation * (7 - 1)) // 2
|
||||||
|
self.dwconv = nn.Conv1d(
|
||||||
|
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||||
|
) # depthwise conv
|
||||||
|
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||||
|
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.grn = GRN(intermediate_dim)
|
||||||
|
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = x.transpose(1, 2) # b n d -> b d n
|
||||||
|
x = self.dwconv(x)
|
||||||
|
x = x.transpose(1, 2) # b d n -> b n d
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.pwconv1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.grn(x)
|
||||||
|
x = self.pwconv2(x)
|
||||||
|
return residual + x
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero
|
||||||
|
# return with modulated x for attn input, and params for later mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 6)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb=None):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||||
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero for final layer
|
||||||
|
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero_Final(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 2)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# FeedForward
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = dim_out if dim_out is not None else dim
|
||||||
|
|
||||||
|
activation = nn.GELU(approximate=approximate)
|
||||||
|
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||||
|
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention with possible joint part
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: JointAttnProcessor | AttnProcessor,
|
||||||
|
dim: int,
|
||||||
|
heads: int = 8,
|
||||||
|
dim_head: int = 64,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||||
|
context_pre_only=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = dim_head * heads
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||||
|
|
||||||
|
if self.context_dim is not None:
|
||||||
|
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
if self.context_pre_only is not None:
|
||||||
|
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
|
||||||
|
self.to_out = nn.ModuleList([])
|
||||||
|
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||||
|
self.to_out.append(nn.Dropout(dropout))
|
||||||
|
|
||||||
|
if self.context_pre_only is not None and not self.context_pre_only:
|
||||||
|
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b n d"] = None, # context c # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if c is not None:
|
||||||
|
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||||
|
else:
|
||||||
|
return self.processor(self, x, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention processor
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# apply rotary position embedding
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = mask
|
||||||
|
if attn_mask.dim() == 2:
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask.dim() == 2:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
mask = mask[:, 0, -1].unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Joint Attention processor for MM-DiT
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class JointAttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
batch_size = c.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# `context` projections.
|
||||||
|
c_query = attn.to_q_c(c)
|
||||||
|
c_key = attn.to_k_c(c)
|
||||||
|
c_value = attn.to_v_c(c)
|
||||||
|
|
||||||
|
# apply rope for context and noised input independently
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
if c_rope is not None:
|
||||||
|
freqs, xpos_scale = c_rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||||
|
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
query = torch.cat([query, c_query], dim=1)
|
||||||
|
key = torch.cat([key, c_key], dim=1)
|
||||||
|
value = torch.cat([value, c_value], dim=1)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# Split the attention outputs.
|
||||||
|
x, c = (
|
||||||
|
x[:, : residual.shape[1]],
|
||||||
|
x[:, residual.shape[1]:],
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
if not attn.context_pre_only:
|
||||||
|
c = attn.to_out_c(c)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||||
|
|
||||||
|
return x, c
|
||||||
|
|
||||||
|
|
||||||
|
# DiT Block
|
||||||
|
|
||||||
|
|
||||||
|
class DiTBlock(nn.Module):
|
||||||
|
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn_norm = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=AttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
|
||||||
|
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||||
|
|
||||||
|
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||||
|
ff_output = self.ff(ff_norm)
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTBlock(nn.Module):
|
||||||
|
r"""
|
||||||
|
modified from diffusers/src/diffusers/models/attention.py
|
||||||
|
|
||||||
|
notes.
|
||||||
|
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||||
|
_x: noised input related. (right part)
|
||||||
|
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||||
|
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=JointAttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=dim,
|
||||||
|
context_pre_only=context_pre_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_pre_only:
|
||||||
|
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
else:
|
||||||
|
self.ff_norm_c = None
|
||||||
|
self.ff_c = None
|
||||||
|
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
|
||||||
|
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
if self.context_pre_only:
|
||||||
|
norm_c = self.attn_norm_c(c, t)
|
||||||
|
else:
|
||||||
|
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
||||||
|
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
||||||
|
|
||||||
|
# process attention output for context c
|
||||||
|
if self.context_pre_only:
|
||||||
|
c = None
|
||||||
|
else: # if not last layer
|
||||||
|
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||||
|
|
||||||
|
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||||
|
c_ff_output = self.ff_c(norm_c)
|
||||||
|
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||||
|
|
||||||
|
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||||
|
x_ff_output = self.ff_x(norm_x)
|
||||||
|
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||||
|
|
||||||
|
return c, x
|
||||||
|
|
||||||
|
|
||||||
|
# time step conditioning embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, freq_embed_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||||
|
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
|
|
||||||
|
def forward(self, timestep: float["b"]): # noqa: F821
|
||||||
|
time_hidden = self.time_embed(timestep)
|
||||||
|
time_hidden = time_hidden.to(timestep.dtype)
|
||||||
|
time = self.time_mlp(time_hidden) # b d
|
||||||
|
return time
|
||||||
286
cosyvoice/flow/decoder.py
Executable file → Normal file
286
cosyvoice/flow/decoder.py
Executable file → Normal file
@@ -11,13 +11,80 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from einops import pack, rearrange, repeat
|
from einops import pack, rearrange, repeat
|
||||||
|
from cosyvoice.utils.common import mask_to_bias
|
||||||
|
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||||
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
||||||
from matcha.models.components.transformer import BasicTransformerBlock
|
from matcha.models.components.transformer import BasicTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Transpose(torch.nn.Module):
|
||||||
|
def __init__(self, dim0: int, dim1: int):
|
||||||
|
super().__init__()
|
||||||
|
self.dim0 = dim0
|
||||||
|
self.dim1 = dim1
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = torch.transpose(x, self.dim0, self.dim1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1d(torch.nn.Conv1d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: int = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
padding_mode: str = 'zeros',
|
||||||
|
device=None,
|
||||||
|
dtype=None
|
||||||
|
) -> None:
|
||||||
|
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||||
|
kernel_size, stride,
|
||||||
|
padding=0, dilation=dilation,
|
||||||
|
groups=groups, bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
assert stride == 1
|
||||||
|
self.causal_padding = kernel_size - 1
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||||
|
x = super(CausalConv1d, self).forward(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalBlock1D(Block1D):
|
||||||
|
def __init__(self, dim: int, dim_out: int):
|
||||||
|
super(CausalBlock1D, self).__init__(dim, dim_out)
|
||||||
|
self.block = torch.nn.Sequential(
|
||||||
|
CausalConv1d(dim, dim_out, 3),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.LayerNorm(dim_out),
|
||||||
|
Transpose(1, 2),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
output = self.block(x * mask)
|
||||||
|
return output * mask
|
||||||
|
|
||||||
|
|
||||||
|
class CausalResnetBlock1D(ResnetBlock1D):
|
||||||
|
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
||||||
|
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
||||||
|
self.block1 = CausalBlock1D(dim, dim_out)
|
||||||
|
self.block2 = CausalBlock1D(dim_out, dim_out)
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDecoder(nn.Module):
|
class ConditionalDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -74,7 +141,7 @@ class ConditionalDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||||
|
|
||||||
for i in range(num_mid_blocks):
|
for _ in range(num_mid_blocks):
|
||||||
input_channel = channels[-1]
|
input_channel = channels[-1]
|
||||||
out_channels = channels[-1]
|
out_channels = channels[-1]
|
||||||
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
@@ -126,7 +193,6 @@ class ConditionalDecoder(nn.Module):
|
|||||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||||
self.initialize_weights()
|
self.initialize_weights()
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
def initialize_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv1d):
|
if isinstance(m, nn.Conv1d):
|
||||||
@@ -141,7 +207,7 @@ class ConditionalDecoder(nn.Module):
|
|||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||||
"""Forward pass of the UNet1DConditional model.
|
"""Forward pass of the UNet1DConditional model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -159,7 +225,7 @@ class ConditionalDecoder(nn.Module):
|
|||||||
_type_: _description_
|
_type_: _description_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
t = self.time_embeddings(t)
|
t = self.time_embeddings(t).to(t.dtype)
|
||||||
t = self.time_mlp(t)
|
t = self.time_mlp(t)
|
||||||
|
|
||||||
x = pack([x, mu], "b * t")[0]
|
x = pack([x, mu], "b * t")[0]
|
||||||
@@ -176,7 +242,8 @@ class ConditionalDecoder(nn.Module):
|
|||||||
mask_down = masks[-1]
|
mask_down = masks[-1]
|
||||||
x = resnet(x, mask_down, t)
|
x = resnet(x, mask_down, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -193,7 +260,8 @@ class ConditionalDecoder(nn.Module):
|
|||||||
for resnet, transformer_blocks in self.mid_blocks:
|
for resnet, transformer_blocks in self.mid_blocks:
|
||||||
x = resnet(x, mask_mid, t)
|
x = resnet(x, mask_mid, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -208,7 +276,211 @@ class ConditionalDecoder(nn.Module):
|
|||||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||||
x = resnet(x, mask_up, t)
|
x = resnet(x, mask_up, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
|
for transformer_block in transformer_blocks:
|
||||||
|
x = transformer_block(
|
||||||
|
hidden_states=x,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
timestep=t,
|
||||||
|
)
|
||||||
|
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||||
|
x = upsample(x * mask_up)
|
||||||
|
x = self.final_block(x, mask_up)
|
||||||
|
output = self.final_proj(x * mask_up)
|
||||||
|
return output * mask
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConditionalDecoder(ConditionalDecoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
channels=(256, 256),
|
||||||
|
dropout=0.05,
|
||||||
|
attention_head_dim=64,
|
||||||
|
n_blocks=1,
|
||||||
|
num_mid_blocks=2,
|
||||||
|
num_heads=4,
|
||||||
|
act_fn="snake",
|
||||||
|
static_chunk_size=50,
|
||||||
|
num_decoding_left_chunks=2,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This decoder requires an input with the same shape of the target. So, if your text content
|
||||||
|
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
||||||
|
"""
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
channels = tuple(channels)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
||||||
|
time_embed_dim = channels[0] * 4
|
||||||
|
self.time_mlp = TimestepEmbedding(
|
||||||
|
in_channels=in_channels,
|
||||||
|
time_embed_dim=time_embed_dim,
|
||||||
|
act_fn="silu",
|
||||||
|
)
|
||||||
|
self.static_chunk_size = static_chunk_size
|
||||||
|
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.mid_blocks = nn.ModuleList([])
|
||||||
|
self.up_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
output_channel = in_channels
|
||||||
|
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = channels[i]
|
||||||
|
is_last = i == len(channels) - 1
|
||||||
|
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
|
transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=output_channel,
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
activation_fn=act_fn,
|
||||||
|
)
|
||||||
|
for _ in range(n_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
downsample = (
|
||||||
|
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
|
||||||
|
)
|
||||||
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||||
|
|
||||||
|
for _ in range(num_mid_blocks):
|
||||||
|
input_channel = channels[-1]
|
||||||
|
out_channels = channels[-1]
|
||||||
|
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
|
|
||||||
|
transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=output_channel,
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
activation_fn=act_fn,
|
||||||
|
)
|
||||||
|
for _ in range(n_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
||||||
|
|
||||||
|
channels = channels[::-1] + (channels[0],)
|
||||||
|
for i in range(len(channels) - 1):
|
||||||
|
input_channel = channels[i] * 2
|
||||||
|
output_channel = channels[i + 1]
|
||||||
|
is_last = i == len(channels) - 2
|
||||||
|
resnet = CausalResnetBlock1D(
|
||||||
|
dim=input_channel,
|
||||||
|
dim_out=output_channel,
|
||||||
|
time_emb_dim=time_embed_dim,
|
||||||
|
)
|
||||||
|
transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=output_channel,
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
activation_fn=act_fn,
|
||||||
|
)
|
||||||
|
for _ in range(n_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
upsample = (
|
||||||
|
Upsample1D(output_channel, use_conv_transpose=True)
|
||||||
|
if not is_last
|
||||||
|
else CausalConv1d(output_channel, output_channel, 3)
|
||||||
|
)
|
||||||
|
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
||||||
|
self.final_block = CausalBlock1D(channels[-1], channels[-1])
|
||||||
|
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||||
|
self.initialize_weights()
|
||||||
|
|
||||||
|
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||||
|
"""Forward pass of the UNet1DConditional model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||||
|
mask (_type_): shape (batch_size, 1, time)
|
||||||
|
t (_type_): shape (batch_size)
|
||||||
|
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||||
|
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: _description_
|
||||||
|
ValueError: _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_type_: _description_
|
||||||
|
"""
|
||||||
|
t = self.time_embeddings(t).to(t.dtype)
|
||||||
|
t = self.time_mlp(t)
|
||||||
|
|
||||||
|
x = pack([x, mu], "b * t")[0]
|
||||||
|
|
||||||
|
if spks is not None:
|
||||||
|
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
||||||
|
x = pack([x, spks], "b * t")[0]
|
||||||
|
if cond is not None:
|
||||||
|
x = pack([x, cond], "b * t")[0]
|
||||||
|
|
||||||
|
hiddens = []
|
||||||
|
masks = [mask]
|
||||||
|
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||||
|
mask_down = masks[-1]
|
||||||
|
x = resnet(x, mask_down, t)
|
||||||
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
|
if streaming is True:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
|
else:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
|
for transformer_block in transformer_blocks:
|
||||||
|
x = transformer_block(
|
||||||
|
hidden_states=x,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
timestep=t,
|
||||||
|
)
|
||||||
|
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||||
|
hiddens.append(x) # Save hidden states for skip connections
|
||||||
|
x = downsample(x * mask_down)
|
||||||
|
masks.append(mask_down[:, :, ::2])
|
||||||
|
masks = masks[:-1]
|
||||||
|
mask_mid = masks[-1]
|
||||||
|
|
||||||
|
for resnet, transformer_blocks in self.mid_blocks:
|
||||||
|
x = resnet(x, mask_mid, t)
|
||||||
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
|
if streaming is True:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
|
else:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
|
for transformer_block in transformer_blocks:
|
||||||
|
x = transformer_block(
|
||||||
|
hidden_states=x,
|
||||||
|
attention_mask=attn_mask,
|
||||||
|
timestep=t,
|
||||||
|
)
|
||||||
|
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||||
|
|
||||||
|
for resnet, transformer_blocks, upsample in self.up_blocks:
|
||||||
|
mask_up = masks.pop()
|
||||||
|
skip = hiddens.pop()
|
||||||
|
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||||
|
x = resnet(x, mask_up, t)
|
||||||
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
|
if streaming is True:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
|
else:
|
||||||
|
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import os, logging
|
||||||
import random
|
import random
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import torch
|
import torch
|
||||||
@@ -19,6 +19,7 @@ import torch.nn as nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from cosyvoice.utils.mask import make_pad_mask
|
from cosyvoice.utils.mask import make_pad_mask
|
||||||
|
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||||
|
|
||||||
|
|
||||||
class MaskedDiffWithXvec(torch.nn.Module):
|
class MaskedDiffWithXvec(torch.nn.Module):
|
||||||
@@ -33,13 +34,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
encoder: torch.nn.Module = None,
|
encoder: torch.nn.Module = None,
|
||||||
length_regulator: torch.nn.Module = None,
|
length_regulator: torch.nn.Module = None,
|
||||||
decoder: torch.nn.Module = None,
|
decoder: torch.nn.Module = None,
|
||||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||||
|
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||||
|
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||||
|
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.decoder_conf = decoder_conf
|
self.decoder_conf = decoder_conf
|
||||||
self.mel_feat_conf = mel_feat_conf
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.output_type = output_type
|
self.output_type = output_type
|
||||||
self.input_frame_rate = input_frame_rate
|
self.input_frame_rate = input_frame_rate
|
||||||
@@ -86,7 +89,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
conds = conds.transpose(1, 2)
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
mask = (~make_pad_mask(feat_len)).to(h)
|
mask = (~make_pad_mask(feat_len)).to(h)
|
||||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
# NOTE this is unnecessary, feat/h already same shape
|
||||||
loss, _ = self.decoder.compute_loss(
|
loss, _ = self.decoder.compute_loss(
|
||||||
feat.transpose(1, 2).contiguous(),
|
feat.transpose(1, 2).contiguous(),
|
||||||
mask.unsqueeze(1),
|
mask.unsqueeze(1),
|
||||||
@@ -104,7 +107,142 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
prompt_token_len,
|
prompt_token_len,
|
||||||
prompt_feat,
|
prompt_feat,
|
||||||
prompt_feat_len,
|
prompt_feat_len,
|
||||||
embedding):
|
embedding,
|
||||||
|
flow_cache):
|
||||||
|
assert token.shape[0] == 1
|
||||||
|
# xvec projection
|
||||||
|
embedding = F.normalize(embedding, dim=1)
|
||||||
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
|
# concat speech token and prompt speech token
|
||||||
|
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||||
|
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||||
|
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||||
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|
||||||
|
# text encode
|
||||||
|
h, h_lengths = self.encoder(token, token_len)
|
||||||
|
h = self.encoder_proj(h)
|
||||||
|
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
||||||
|
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||||
|
|
||||||
|
# get conditions
|
||||||
|
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||||
|
conds[:, :mel_len1] = prompt_feat
|
||||||
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
|
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||||
|
feat, flow_cache = self.decoder(
|
||||||
|
mu=h.transpose(1, 2).contiguous(),
|
||||||
|
mask=mask.unsqueeze(1),
|
||||||
|
spks=embedding,
|
||||||
|
cond=conds,
|
||||||
|
n_timesteps=10,
|
||||||
|
prompt_len=mel_len1,
|
||||||
|
cache=flow_cache
|
||||||
|
)
|
||||||
|
feat = feat[:, :, mel_len1:]
|
||||||
|
assert feat.shape[2] == mel_len2
|
||||||
|
return feat.float(), flow_cache
|
||||||
|
|
||||||
|
|
||||||
|
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
input_size: int = 512,
|
||||||
|
output_size: int = 80,
|
||||||
|
spk_embed_dim: int = 192,
|
||||||
|
output_type: str = "mel",
|
||||||
|
vocab_size: int = 4096,
|
||||||
|
input_frame_rate: int = 50,
|
||||||
|
only_mask_loss: bool = True,
|
||||||
|
token_mel_ratio: int = 2,
|
||||||
|
pre_lookahead_len: int = 3,
|
||||||
|
encoder: torch.nn.Module = None,
|
||||||
|
decoder: torch.nn.Module = None,
|
||||||
|
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||||
|
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||||
|
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||||
|
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||||
|
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
self.decoder_conf = decoder_conf
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.output_type = output_type
|
||||||
|
self.input_frame_rate = input_frame_rate
|
||||||
|
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||||
|
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||||
|
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||||
|
self.decoder = decoder
|
||||||
|
self.only_mask_loss = only_mask_loss
|
||||||
|
self.token_mel_ratio = token_mel_ratio
|
||||||
|
self.pre_lookahead_len = pre_lookahead_len
|
||||||
|
if online_feature is True:
|
||||||
|
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
|
token = batch['speech_token'].to(device)
|
||||||
|
token_len = batch['speech_token_len'].to(device)
|
||||||
|
feat = batch['speech_feat'].to(device)
|
||||||
|
feat_len = batch['speech_feat_len'].to(device)
|
||||||
|
embedding = batch['embedding'].to(device)
|
||||||
|
|
||||||
|
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||||
|
streaming = True if random.random() < 0.5 else False
|
||||||
|
|
||||||
|
# xvec projection
|
||||||
|
embedding = F.normalize(embedding, dim=1)
|
||||||
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
|
# concat text and prompt_text
|
||||||
|
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||||
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|
||||||
|
# text encode
|
||||||
|
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||||
|
h = self.encoder_proj(h)
|
||||||
|
|
||||||
|
# get conditions
|
||||||
|
conds = torch.zeros(feat.shape, device=token.device)
|
||||||
|
for i, j in enumerate(feat_len):
|
||||||
|
if random.random() < 0.5:
|
||||||
|
continue
|
||||||
|
index = random.randint(0, int(0.3 * j))
|
||||||
|
conds[i, :index] = feat[i, :index]
|
||||||
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
|
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
||||||
|
loss, _ = self.decoder.compute_loss(
|
||||||
|
feat.transpose(1, 2).contiguous(),
|
||||||
|
mask.unsqueeze(1),
|
||||||
|
h.transpose(1, 2).contiguous(),
|
||||||
|
embedding,
|
||||||
|
cond=conds,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
return {'loss': loss}
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(self,
|
||||||
|
token,
|
||||||
|
token_len,
|
||||||
|
prompt_token,
|
||||||
|
prompt_token_len,
|
||||||
|
prompt_feat,
|
||||||
|
prompt_feat_len,
|
||||||
|
embedding,
|
||||||
|
streaming,
|
||||||
|
finalize):
|
||||||
assert token.shape[0] == 1
|
assert token.shape[0] == 1
|
||||||
# xvec projection
|
# xvec projection
|
||||||
embedding = F.normalize(embedding, dim=1)
|
embedding = F.normalize(embedding, dim=1)
|
||||||
@@ -112,30 +250,194 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|||||||
|
|
||||||
# concat text and prompt_text
|
# concat text and prompt_text
|
||||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
|
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|
||||||
# text encode
|
# text encode
|
||||||
h, h_lengths = self.encoder(token, token_len)
|
if finalize is True:
|
||||||
|
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||||
|
else:
|
||||||
|
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
||||||
|
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
||||||
|
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||||
h = self.encoder_proj(h)
|
h = self.encoder_proj(h)
|
||||||
feat_len = (token_len / 50 * 22050 / 256).int()
|
|
||||||
h, h_lengths = self.length_regulator(h, feat_len)
|
|
||||||
|
|
||||||
# get conditions
|
# get conditions
|
||||||
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
|
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||||
if prompt_feat.shape[1] != 0:
|
conds[:, :mel_len1] = prompt_feat
|
||||||
for i, j in enumerate(prompt_feat_len):
|
|
||||||
conds[i, :j] = prompt_feat[i]
|
|
||||||
conds = conds.transpose(1, 2)
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
mask = (~make_pad_mask(feat_len)).to(h)
|
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||||
feat = self.decoder(
|
feat, _ = self.decoder(
|
||||||
mu=h.transpose(1, 2).contiguous(),
|
mu=h.transpose(1, 2).contiguous(),
|
||||||
mask=mask.unsqueeze(1),
|
mask=mask.unsqueeze(1),
|
||||||
spks=embedding,
|
spks=embedding,
|
||||||
cond=conds,
|
cond=conds,
|
||||||
n_timesteps=10
|
n_timesteps=10,
|
||||||
|
streaming=streaming
|
||||||
)
|
)
|
||||||
if prompt_feat.shape[1] != 0:
|
feat = feat[:, :, mel_len1:]
|
||||||
feat = feat[:, :, prompt_feat.shape[1]:]
|
assert feat.shape[2] == mel_len2
|
||||||
return feat
|
return feat.float(), None
|
||||||
|
|
||||||
|
|
||||||
|
class CausalMaskedDiffWithDiT(torch.nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
input_size: int = 512,
|
||||||
|
output_size: int = 80,
|
||||||
|
spk_embed_dim: int = 192,
|
||||||
|
output_type: str = "mel",
|
||||||
|
vocab_size: int = 4096,
|
||||||
|
input_frame_rate: int = 50,
|
||||||
|
only_mask_loss: bool = True,
|
||||||
|
token_mel_ratio: int = 2,
|
||||||
|
pre_lookahead_len: int = 3,
|
||||||
|
pre_lookahead_layer: torch.nn.Module = None,
|
||||||
|
decoder: torch.nn.Module = None,
|
||||||
|
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||||
|
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||||
|
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||||
|
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||||
|
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
self.decoder_conf = decoder_conf
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.output_type = output_type
|
||||||
|
self.input_frame_rate = input_frame_rate
|
||||||
|
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||||
|
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||||
|
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||||
|
self.pre_lookahead_len = pre_lookahead_len
|
||||||
|
self.pre_lookahead_layer = pre_lookahead_layer
|
||||||
|
self.decoder = decoder
|
||||||
|
self.only_mask_loss = only_mask_loss
|
||||||
|
self.token_mel_ratio = token_mel_ratio
|
||||||
|
if online_feature is True:
|
||||||
|
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
|
token = batch['speech_token'].to(device)
|
||||||
|
token_len = batch['speech_token_len'].to(device)
|
||||||
|
feat = batch['speech_feat'].to(device)
|
||||||
|
feat_len = batch['speech_feat_len'].to(device)
|
||||||
|
embedding = batch['embedding'].to(device)
|
||||||
|
|
||||||
|
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||||
|
streaming = True if random.random() < 0.5 else False
|
||||||
|
|
||||||
|
# xvec projection
|
||||||
|
embedding = F.normalize(embedding, dim=1)
|
||||||
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
|
# concat text and prompt_text
|
||||||
|
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||||
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|
||||||
|
# text encode
|
||||||
|
h = self.pre_lookahead_layer(token)
|
||||||
|
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||||
|
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
|
||||||
|
|
||||||
|
# get conditions
|
||||||
|
conds = torch.zeros(feat.shape, device=token.device)
|
||||||
|
for i, j in enumerate(feat_len):
|
||||||
|
if random.random() < 0.5:
|
||||||
|
continue
|
||||||
|
index = random.randint(0, int(0.3 * j))
|
||||||
|
conds[i, :index] = feat[i, :index]
|
||||||
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
|
loss, _ = self.decoder.compute_loss(
|
||||||
|
feat.transpose(1, 2).contiguous(),
|
||||||
|
mask.unsqueeze(1),
|
||||||
|
h.transpose(1, 2).contiguous(),
|
||||||
|
embedding,
|
||||||
|
cond=conds,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
return {'loss': loss}
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(self,
|
||||||
|
token,
|
||||||
|
token_len,
|
||||||
|
prompt_token,
|
||||||
|
prompt_token_len,
|
||||||
|
prompt_feat,
|
||||||
|
prompt_feat_len,
|
||||||
|
embedding,
|
||||||
|
streaming,
|
||||||
|
finalize):
|
||||||
|
assert token.shape[0] == 1
|
||||||
|
# xvec projection
|
||||||
|
embedding = F.normalize(embedding, dim=1)
|
||||||
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
|
# concat text and prompt_text
|
||||||
|
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||||
|
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||||
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|
||||||
|
# text encode
|
||||||
|
if finalize is True:
|
||||||
|
h = self.pre_lookahead_layer(token)
|
||||||
|
else:
|
||||||
|
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
|
||||||
|
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||||
|
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||||
|
|
||||||
|
# get conditions
|
||||||
|
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||||
|
conds[:, :mel_len1] = prompt_feat
|
||||||
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
|
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||||
|
feat, _ = self.decoder(
|
||||||
|
mu=h.transpose(1, 2).contiguous(),
|
||||||
|
mask=mask.unsqueeze(1),
|
||||||
|
spks=embedding,
|
||||||
|
cond=conds,
|
||||||
|
n_timesteps=10,
|
||||||
|
streaming=streaming
|
||||||
|
)
|
||||||
|
feat = feat[:, :, mel_len1:]
|
||||||
|
assert feat.shape[2] == mel_len2
|
||||||
|
return feat.float(), None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||||
|
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
|
||||||
|
model = configs['flow']
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
max_len = 10 * model.decoder.estimator.static_chunk_size
|
||||||
|
chunk_size = model.decoder.estimator.static_chunk_size
|
||||||
|
context_size = model.pre_lookahead_layer.pre_lookahead_len
|
||||||
|
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
|
||||||
|
token_len = torch.tensor([max_len]).to(device)
|
||||||
|
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
|
||||||
|
prompt_token_len = torch.tensor([chunk_size]).to(device)
|
||||||
|
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
|
||||||
|
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
|
||||||
|
prompt_embedding = torch.rand(1, 192).to(device)
|
||||||
|
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
|
||||||
|
for i in range(0, max_len, chunk_size):
|
||||||
|
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||||
|
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
|
||||||
|
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
|
||||||
|
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
|
||||||
|
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())
|
||||||
|
|||||||
129
cosyvoice/flow/flow_matching.py
Executable file → Normal file
129
cosyvoice/flow/flow_matching.py
Executable file → Normal file
@@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||||
|
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -14,6 +15,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from matcha.models.components.flow_matching import BASECFM
|
from matcha.models.components.flow_matching import BASECFM
|
||||||
|
from cosyvoice.utils.common import set_all_random_seed
|
||||||
|
|
||||||
|
|
||||||
class ConditionalCFM(BASECFM):
|
class ConditionalCFM(BASECFM):
|
||||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||||
@@ -31,7 +34,7 @@ class ConditionalCFM(BASECFM):
|
|||||||
self.estimator = estimator
|
self.estimator = estimator
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||||
"""Forward diffusion
|
"""Forward diffusion
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -49,13 +52,23 @@ class ConditionalCFM(BASECFM):
|
|||||||
sample: generated mel-spectrogram
|
sample: generated mel-spectrogram
|
||||||
shape: (batch_size, n_feats, mel_timesteps)
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
z = torch.randn_like(mu) * temperature
|
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||||
|
cache_size = cache.shape[2]
|
||||||
|
# fix prompt and overlap part mu and z
|
||||||
|
if cache_size != 0:
|
||||||
|
z[:, :, :cache_size] = cache[:, :, :, 0]
|
||||||
|
mu[:, :, :cache_size] = cache[:, :, :, 1]
|
||||||
|
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
||||||
|
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
||||||
|
cache = torch.stack([z_cache, mu_cache], dim=-1)
|
||||||
|
|
||||||
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||||
if self.t_scheduler == 'cosine':
|
if self.t_scheduler == 'cosine':
|
||||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
||||||
|
|
||||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
||||||
"""
|
"""
|
||||||
Fixed euler solver for ODEs.
|
Fixed euler solver for ODEs.
|
||||||
Args:
|
Args:
|
||||||
@@ -71,32 +84,75 @@ class ConditionalCFM(BASECFM):
|
|||||||
cond: Not used but kept for future purposes
|
cond: Not used but kept for future purposes
|
||||||
"""
|
"""
|
||||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||||
|
t = t.unsqueeze(dim=0)
|
||||||
|
|
||||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||||
# Or in future might add like a return_all_steps flag
|
# Or in future might add like a return_all_steps flag
|
||||||
sol = []
|
sol = []
|
||||||
|
|
||||||
|
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||||
|
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
|
||||||
|
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||||
|
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||||
|
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||||
|
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
|
||||||
|
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
|
||||||
|
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||||
for step in range(1, len(t_span)):
|
for step in range(1, len(t_span)):
|
||||||
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
|
||||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||||
if self.inference_cfg_rate > 0:
|
x_in[:] = x
|
||||||
cfg_dphi_dt = self.estimator(
|
mask_in[:] = mask
|
||||||
x, mask,
|
mu_in[0] = mu
|
||||||
torch.zeros_like(mu), t,
|
t_in[:] = t.unsqueeze(0)
|
||||||
torch.zeros_like(spks) if spks is not None else None,
|
spks_in[0] = spks
|
||||||
torch.zeros_like(cond)
|
cond_in[0] = cond
|
||||||
)
|
dphi_dt = self.forward_estimator(
|
||||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
x_in, mask_in,
|
||||||
self.inference_cfg_rate * cfg_dphi_dt)
|
mu_in, t_in,
|
||||||
|
spks_in,
|
||||||
|
cond_in,
|
||||||
|
streaming
|
||||||
|
)
|
||||||
|
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||||
|
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||||
x = x + dt * dphi_dt
|
x = x + dt * dphi_dt
|
||||||
t = t + dt
|
t = t + dt
|
||||||
sol.append(x)
|
sol.append(x)
|
||||||
if step < len(t_span) - 1:
|
if step < len(t_span) - 1:
|
||||||
dt = t_span[step + 1] - t
|
dt = t_span[step + 1] - t
|
||||||
|
|
||||||
return sol[-1]
|
return sol[-1].float()
|
||||||
|
|
||||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
|
||||||
|
if isinstance(self.estimator, torch.nn.Module):
|
||||||
|
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||||
|
else:
|
||||||
|
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
||||||
|
# NOTE need to synchronize when switching stream
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
with stream:
|
||||||
|
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
|
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||||
|
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||||
|
estimator.set_input_shape('t', (2,))
|
||||||
|
estimator.set_input_shape('spks', (2, 80))
|
||||||
|
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||||
|
data_ptrs = [x.contiguous().data_ptr(),
|
||||||
|
mask.contiguous().data_ptr(),
|
||||||
|
mu.contiguous().data_ptr(),
|
||||||
|
t.contiguous().data_ptr(),
|
||||||
|
spks.contiguous().data_ptr(),
|
||||||
|
cond.contiguous().data_ptr(),
|
||||||
|
x.data_ptr()]
|
||||||
|
for i, j in enumerate(data_ptrs):
|
||||||
|
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||||
|
# run trt engine
|
||||||
|
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.estimator.release_estimator(estimator, stream)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||||
"""Computes diffusion loss
|
"""Computes diffusion loss
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -118,8 +174,7 @@ class ConditionalCFM(BASECFM):
|
|||||||
|
|
||||||
# random timestep
|
# random timestep
|
||||||
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
||||||
if self.t_scheduler == 'cosine':
|
|
||||||
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
|
||||||
# sample noise p(x_0)
|
# sample noise p(x_0)
|
||||||
z = torch.randn_like(x1)
|
z = torch.randn_like(x1)
|
||||||
|
|
||||||
@@ -133,6 +188,40 @@ class ConditionalCFM(BASECFM):
|
|||||||
spks = spks * cfg_mask.view(-1, 1)
|
spks = spks * cfg_mask.view(-1, 1)
|
||||||
cond = cond * cfg_mask.view(-1, 1, 1)
|
cond = cond * cfg_mask.view(-1, 1, 1)
|
||||||
|
|
||||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
||||||
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
||||||
return loss, y
|
return loss, y
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConditionalCFM(ConditionalCFM):
|
||||||
|
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||||
|
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||||
|
set_all_random_seed(0)
|
||||||
|
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
||||||
|
"""Forward diffusion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mu (torch.Tensor): output of encoder
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
|
mask (torch.Tensor): output_mask
|
||||||
|
shape: (batch_size, 1, mel_timesteps)
|
||||||
|
n_timesteps (int): number of diffusion steps
|
||||||
|
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||||
|
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||||
|
shape: (batch_size, spk_emb_dim)
|
||||||
|
cond: Not used but kept for future purposes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sample: generated mel-spectrogram
|
||||||
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||||
|
# fix prompt and overlap part mu and z
|
||||||
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||||
|
if self.t_scheduler == 'cosine':
|
||||||
|
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||||
|
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
||||||
|
|||||||
23
cosyvoice/flow/length_regulator.py
Executable file → Normal file
23
cosyvoice/flow/length_regulator.py
Executable file → Normal file
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from cosyvoice.utils.mask import make_pad_mask
|
from cosyvoice.utils.mask import make_pad_mask
|
||||||
|
|
||||||
@@ -43,7 +44,27 @@ class InterpolateRegulator(nn.Module):
|
|||||||
def forward(self, x, ylens=None):
|
def forward(self, x, ylens=None):
|
||||||
# x in (B, T, D)
|
# x in (B, T, D)
|
||||||
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
||||||
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
|
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
||||||
out = self.model(x).transpose(1, 2).contiguous()
|
out = self.model(x).transpose(1, 2).contiguous()
|
||||||
olens = ylens
|
olens = ylens
|
||||||
return out * mask, olens
|
return out * mask, olens
|
||||||
|
|
||||||
|
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
||||||
|
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
||||||
|
# NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
|
||||||
|
# x in (B, T, D)
|
||||||
|
if x2.shape[1] > 40:
|
||||||
|
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||||
|
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
||||||
|
mode='linear')
|
||||||
|
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||||
|
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
||||||
|
else:
|
||||||
|
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
||||||
|
if x1.shape[1] != 0:
|
||||||
|
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
||||||
|
x = torch.concat([x1, x2], dim=2)
|
||||||
|
else:
|
||||||
|
x = x2
|
||||||
|
out = self.model(x).transpose(1, 2).contiguous()
|
||||||
|
return out, mel_len1 + mel_len2
|
||||||
|
|||||||
230
cosyvoice/hifigan/discriminator.py
Normal file
230
cosyvoice/hifigan/discriminator.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
try:
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
|
||||||
|
except ImportError:
|
||||||
|
from torch.nn.utils import weight_norm, spectral_norm
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
from einops import rearrange
|
||||||
|
from torchaudio.transforms import Spectrogram
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class MultipleDiscriminator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, mpd: nn.Module, mrd: nn.Module
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.mpd = mpd
|
||||||
|
self.mrd = mrd
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
||||||
|
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
||||||
|
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
||||||
|
y_d_rs += this_y_d_rs
|
||||||
|
y_d_gs += this_y_d_gs
|
||||||
|
fmap_rs += this_fmap_rs
|
||||||
|
fmap_gs += this_fmap_gs
|
||||||
|
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
||||||
|
y_d_rs += this_y_d_rs
|
||||||
|
y_d_gs += this_y_d_gs
|
||||||
|
fmap_rs += this_fmap_rs
|
||||||
|
fmap_gs += this_fmap_gs
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class MultiResolutionDiscriminator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
||||||
|
num_embeddings: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
||||||
|
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
||||||
|
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.discriminators = nn.ModuleList(
|
||||||
|
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||||
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for d in self.discriminators:
|
||||||
|
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||||
|
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorR(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_length: int,
|
||||||
|
num_embeddings: Optional[int] = None,
|
||||||
|
channels: int = 32,
|
||||||
|
hop_factor: float = 0.25,
|
||||||
|
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.window_length = window_length
|
||||||
|
self.hop_factor = hop_factor
|
||||||
|
self.spec_fn = Spectrogram(
|
||||||
|
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
||||||
|
)
|
||||||
|
n_fft = window_length // 2 + 1
|
||||||
|
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||||
|
self.bands = bands
|
||||||
|
convs = lambda: nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||||
|
|
||||||
|
if num_embeddings is not None:
|
||||||
|
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
||||||
|
torch.nn.init.zeros_(self.emb.weight)
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||||
|
|
||||||
|
def spectrogram(self, x):
|
||||||
|
# Remove DC offset
|
||||||
|
x = x - x.mean(dim=-1, keepdims=True)
|
||||||
|
# Peak normalize the volume of input audio
|
||||||
|
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||||
|
x = self.spec_fn(x)
|
||||||
|
x = torch.view_as_real(x)
|
||||||
|
x = rearrange(x, "b f t c -> b c t f")
|
||||||
|
# Split into bands
|
||||||
|
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
||||||
|
return x_bands
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||||
|
x_bands = self.spectrogram(x)
|
||||||
|
fmap = []
|
||||||
|
x = []
|
||||||
|
for band, stack in zip(x_bands, self.band_convs):
|
||||||
|
for i, layer in enumerate(stack):
|
||||||
|
band = layer(band)
|
||||||
|
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||||
|
if i > 0:
|
||||||
|
fmap.append(band)
|
||||||
|
x.append(band)
|
||||||
|
x = torch.cat(x, dim=-1)
|
||||||
|
if cond_embedding_id is not None:
|
||||||
|
emb = self.emb(cond_embedding_id)
|
||||||
|
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||||
|
else:
|
||||||
|
h = 0
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x += h
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
fft_sizes=[1024, 2048, 512],
|
||||||
|
hop_sizes=[120, 240, 50],
|
||||||
|
win_lengths=[600, 1200, 240],
|
||||||
|
window="hann_window"):
|
||||||
|
|
||||||
|
super(MultiResSpecDiscriminator, self).__init__()
|
||||||
|
self.discriminators = nn.ModuleList([
|
||||||
|
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
||||||
|
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
||||||
|
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
for _, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
def stft(x, fft_size, hop_size, win_length, window):
|
||||||
|
"""Perform STFT and convert to magnitude spectrogram.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input signal tensor (B, T).
|
||||||
|
fft_size (int): FFT size.
|
||||||
|
hop_size (int): Hop size.
|
||||||
|
win_length (int): Window length.
|
||||||
|
window (str): Window function type.
|
||||||
|
Returns:
|
||||||
|
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||||
|
"""
|
||||||
|
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
|
||||||
|
|
||||||
|
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
||||||
|
return torch.abs(x_stft).transpose(2, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class SpecDiscriminator(nn.Module):
|
||||||
|
"""docstring for Discriminator."""
|
||||||
|
|
||||||
|
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
|
||||||
|
super(SpecDiscriminator, self).__init__()
|
||||||
|
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||||
|
self.fft_size = fft_size
|
||||||
|
self.shift_size = shift_size
|
||||||
|
self.win_length = win_length
|
||||||
|
self.window = getattr(torch, window)(win_length)
|
||||||
|
self.discriminators = nn.ModuleList([
|
||||||
|
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||||
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
||||||
|
])
|
||||||
|
|
||||||
|
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, y):
|
||||||
|
|
||||||
|
fmap = []
|
||||||
|
y = y.squeeze(1)
|
||||||
|
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
|
||||||
|
y = y.unsqueeze(1)
|
||||||
|
for _, d in enumerate(self.discriminators):
|
||||||
|
y = d(y)
|
||||||
|
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||||
|
fmap.append(y)
|
||||||
|
|
||||||
|
y = self.out(y)
|
||||||
|
fmap.append(y)
|
||||||
|
|
||||||
|
return torch.flatten(y, 1, -1), fmap
|
||||||
50
cosyvoice/hifigan/f0_predictor.py
Executable file → Normal file
50
cosyvoice/hifigan/f0_predictor.py
Executable file → Normal file
@@ -13,7 +13,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import weight_norm
|
try:
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
except ImportError:
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
from cosyvoice.transformer.convolution import CausalConv1d
|
||||||
|
|
||||||
|
|
||||||
class ConvRNNF0Predictor(nn.Module):
|
class ConvRNNF0Predictor(nn.Module):
|
||||||
@@ -53,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
|
|||||||
x = self.condnet(x)
|
x = self.condnet(x)
|
||||||
x = x.transpose(1, 2)
|
x = x.transpose(1, 2)
|
||||||
return torch.abs(self.classifier(x).squeeze(-1))
|
return torch.abs(self.classifier(x).squeeze(-1))
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConvRNNF0Predictor(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_class: int = 1,
|
||||||
|
in_channels: int = 80,
|
||||||
|
cond_channels: int = 512
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_class = num_class
|
||||||
|
self.condnet = nn.Sequential(
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
|
||||||
|
),
|
||||||
|
nn.ELU(),
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||||
|
),
|
||||||
|
nn.ELU(),
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||||
|
),
|
||||||
|
nn.ELU(),
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||||
|
),
|
||||||
|
nn.ELU(),
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||||
|
),
|
||||||
|
nn.ELU(),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||||
|
if finalize is True:
|
||||||
|
x = self.condnet[0](x)
|
||||||
|
else:
|
||||||
|
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
|
||||||
|
for i in range(1, len(self.condnet)):
|
||||||
|
x = self.condnet[i](x)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
return torch.abs(self.classifier(x).squeeze(-1))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
"""HIFI-GAN"""
|
"""HIFI-GAN"""
|
||||||
|
|
||||||
import typing as tp
|
from typing import Dict, Optional, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.signal import get_window
|
from scipy.signal import get_window
|
||||||
import torch
|
import torch
|
||||||
@@ -23,9 +23,12 @@ import torch.nn.functional as F
|
|||||||
from torch.nn import Conv1d
|
from torch.nn import Conv1d
|
||||||
from torch.nn import ConvTranspose1d
|
from torch.nn import ConvTranspose1d
|
||||||
from torch.nn.utils import remove_weight_norm
|
from torch.nn.utils import remove_weight_norm
|
||||||
from torch.nn.utils import weight_norm
|
try:
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
except ImportError:
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
from torch.distributions.uniform import Uniform
|
from torch.distributions.uniform import Uniform
|
||||||
|
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
|
||||||
from cosyvoice.transformer.activation import Snake
|
from cosyvoice.transformer.activation import Snake
|
||||||
from cosyvoice.utils.common import get_padding
|
from cosyvoice.utils.common import get_padding
|
||||||
from cosyvoice.utils.common import init_weights
|
from cosyvoice.utils.common import init_weights
|
||||||
@@ -38,15 +41,19 @@ This code is modified from https://github.com/jik876/hifi-gan
|
|||||||
https://github.com/NVIDIA/BigVGAN
|
https://github.com/NVIDIA/BigVGAN
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(torch.nn.Module):
|
class ResBlock(torch.nn.Module):
|
||||||
"""Residual block module in HiFiGAN/BigVGAN."""
|
"""Residual block module in HiFiGAN/BigVGAN."""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int = 512,
|
channels: int = 512,
|
||||||
kernel_size: int = 3,
|
kernel_size: int = 3,
|
||||||
dilations: tp.List[int] = [1, 3, 5],
|
dilations: List[int] = [1, 3, 5],
|
||||||
|
causal: bool = False,
|
||||||
):
|
):
|
||||||
super(ResBlock, self).__init__()
|
super(ResBlock, self).__init__()
|
||||||
|
self.causal = causal
|
||||||
self.convs1 = nn.ModuleList()
|
self.convs1 = nn.ModuleList()
|
||||||
self.convs2 = nn.ModuleList()
|
self.convs2 = nn.ModuleList()
|
||||||
|
|
||||||
@@ -59,7 +66,14 @@ class ResBlock(torch.nn.Module):
|
|||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
padding=get_padding(kernel_size, dilation)
|
padding=get_padding(kernel_size, dilation)) if causal is False else
|
||||||
|
CausalConv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation,
|
||||||
|
causal_type='left'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -71,7 +85,14 @@ class ResBlock(torch.nn.Module):
|
|||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
padding=get_padding(kernel_size, 1)
|
padding=get_padding(kernel_size, 1)) if causal is False else
|
||||||
|
CausalConv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
causal_type='left'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -100,6 +121,7 @@ class ResBlock(torch.nn.Module):
|
|||||||
remove_weight_norm(self.convs1[idx])
|
remove_weight_norm(self.convs1[idx])
|
||||||
remove_weight_norm(self.convs2[idx])
|
remove_weight_norm(self.convs2[idx])
|
||||||
|
|
||||||
|
|
||||||
class SineGen(torch.nn.Module):
|
class SineGen(torch.nn.Module):
|
||||||
""" Definition of sine generator
|
""" Definition of sine generator
|
||||||
SineGen(samp_rate, harmonic_num = 0,
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
@@ -133,11 +155,13 @@ class SineGen(torch.nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, f0):
|
def forward(self, f0):
|
||||||
|
""" sine_tensor, uv = forward(f0)
|
||||||
|
input F0: tensor(batchsize=1, dim=1, length)
|
||||||
|
f0 for unvoiced steps should be 0
|
||||||
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
"""
|
"""
|
||||||
:param f0: [B, 1, sample_len], Hz
|
f0 = f0.transpose(1, 2)
|
||||||
:return: [B, 1, sample_len]
|
|
||||||
"""
|
|
||||||
|
|
||||||
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
||||||
for i in range(self.harmonic_num + 1):
|
for i in range(self.harmonic_num + 1):
|
||||||
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
||||||
@@ -159,6 +183,134 @@ class SineGen(torch.nn.Module):
|
|||||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||||
noise = noise_amp * torch.randn_like(sine_waves)
|
noise = noise_amp * torch.randn_like(sine_waves)
|
||||||
|
|
||||||
|
# first: set the unvoiced part to 0 by uv
|
||||||
|
# then: additive noise
|
||||||
|
sine_waves = sine_waves * uv + noise
|
||||||
|
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
|
||||||
|
|
||||||
|
|
||||||
|
class SineGen2(torch.nn.Module):
|
||||||
|
""" Definition of sine generator
|
||||||
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
|
sine_amp = 0.1, noise_std = 0.003,
|
||||||
|
voiced_threshold = 0,
|
||||||
|
flag_for_pulse=False)
|
||||||
|
samp_rate: sampling rate in Hz
|
||||||
|
harmonic_num: number of harmonic overtones (default 0)
|
||||||
|
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||||
|
noise_std: std of Gaussian noise (default 0.003)
|
||||||
|
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||||
|
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||||
|
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||||
|
segment is always sin(np.pi) or cos(0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||||
|
sine_amp=0.1, noise_std=0.003,
|
||||||
|
voiced_threshold=0,
|
||||||
|
flag_for_pulse=False,
|
||||||
|
causal=False):
|
||||||
|
super(SineGen2, self).__init__()
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = noise_std
|
||||||
|
self.harmonic_num = harmonic_num
|
||||||
|
self.dim = self.harmonic_num + 1
|
||||||
|
self.sampling_rate = samp_rate
|
||||||
|
self.voiced_threshold = voiced_threshold
|
||||||
|
self.flag_for_pulse = flag_for_pulse
|
||||||
|
self.upsample_scale = upsample_scale
|
||||||
|
self.causal = causal
|
||||||
|
if causal is True:
|
||||||
|
self.rand_ini = torch.rand(1, 9)
|
||||||
|
self.rand_ini[:, 0] = 0
|
||||||
|
self.sine_waves = torch.rand(1, 300 * 24000, 9)
|
||||||
|
|
||||||
|
def _f02uv(self, f0):
|
||||||
|
# generate uv signal
|
||||||
|
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||||
|
return uv
|
||||||
|
|
||||||
|
def _f02sine(self, f0_values):
|
||||||
|
""" f0_values: (batchsize, length, dim)
|
||||||
|
where dim indicates fundamental tone and overtones
|
||||||
|
"""
|
||||||
|
# convert to F0 in rad. The interger part n can be ignored
|
||||||
|
# because 2 * np.pi * n doesn't affect phase
|
||||||
|
rad_values = (f0_values / self.sampling_rate) % 1
|
||||||
|
|
||||||
|
# initial phase noise (no noise for fundamental component)
|
||||||
|
if self.training is False and self.causal is True:
|
||||||
|
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
|
||||||
|
else:
|
||||||
|
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||||
|
rand_ini[:, 0] = 0
|
||||||
|
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||||
|
|
||||||
|
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||||
|
if not self.flag_for_pulse:
|
||||||
|
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||||
|
scale_factor=1 / self.upsample_scale,
|
||||||
|
mode="linear").transpose(1, 2)
|
||||||
|
|
||||||
|
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||||
|
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||||
|
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
|
||||||
|
sines = torch.sin(phase)
|
||||||
|
else:
|
||||||
|
# If necessary, make sure that the first time step of every
|
||||||
|
# voiced segments is sin(pi) or cos(0)
|
||||||
|
# This is used for pulse-train generation
|
||||||
|
|
||||||
|
# identify the last time step in unvoiced segments
|
||||||
|
uv = self._f02uv(f0_values)
|
||||||
|
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||||
|
uv_1[:, -1, :] = 1
|
||||||
|
u_loc = (uv < 1) * (uv_1 > 0)
|
||||||
|
|
||||||
|
# get the instantanouse phase
|
||||||
|
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||||
|
# different batch needs to be processed differently
|
||||||
|
for idx in range(f0_values.shape[0]):
|
||||||
|
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||||
|
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||||
|
# stores the accumulation of i.phase within
|
||||||
|
# each voiced segments
|
||||||
|
tmp_cumsum[idx, :, :] = 0
|
||||||
|
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||||
|
|
||||||
|
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||||
|
# within the previous voiced segment.
|
||||||
|
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||||
|
|
||||||
|
# get the sines
|
||||||
|
sines = torch.cos(i_phase * 2 * np.pi)
|
||||||
|
return sines
|
||||||
|
|
||||||
|
def forward(self, f0):
|
||||||
|
""" sine_tensor, uv = forward(f0)
|
||||||
|
input F0: tensor(batchsize=1, length, dim=1)
|
||||||
|
f0 for unvoiced steps should be 0
|
||||||
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
|
"""
|
||||||
|
# fundamental component
|
||||||
|
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||||
|
|
||||||
|
# generate sine waveforms
|
||||||
|
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||||
|
|
||||||
|
# generate uv signal
|
||||||
|
uv = self._f02uv(f0)
|
||||||
|
|
||||||
|
# noise: for unvoiced should be similar to sine_amp
|
||||||
|
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||||
|
# . for voiced regions is self.noise_std
|
||||||
|
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||||
|
if self.training is False and self.causal is True:
|
||||||
|
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
|
||||||
|
else:
|
||||||
|
noise = noise_amp * torch.randn_like(sine_waves)
|
||||||
|
|
||||||
# first: set the unvoiced part to 0 by uv
|
# first: set the unvoiced part to 0 by uv
|
||||||
# then: additive noise
|
# then: additive noise
|
||||||
sine_waves = sine_waves * uv + noise
|
sine_waves = sine_waves * uv + noise
|
||||||
@@ -184,19 +336,24 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||||
add_noise_std=0.003, voiced_threshod=0):
|
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
|
||||||
super(SourceModuleHnNSF, self).__init__()
|
super(SourceModuleHnNSF, self).__init__()
|
||||||
|
|
||||||
self.sine_amp = sine_amp
|
self.sine_amp = sine_amp
|
||||||
self.noise_std = add_noise_std
|
self.noise_std = add_noise_std
|
||||||
|
|
||||||
# to produce sine waveforms
|
# to produce sine waveforms
|
||||||
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
if sinegen_type == '1':
|
||||||
sine_amp, add_noise_std, voiced_threshod)
|
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
|
||||||
|
else:
|
||||||
|
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
|
||||||
|
|
||||||
# to merge source harmonics into a single excitation
|
# to merge source harmonics into a single excitation
|
||||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||||
self.l_tanh = torch.nn.Tanh()
|
self.l_tanh = torch.nn.Tanh()
|
||||||
|
self.causal = causal
|
||||||
|
if causal is True:
|
||||||
|
self.uv = torch.rand(1, 300 * 24000, 1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -207,13 +364,14 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
# source for harmonic branch
|
# source for harmonic branch
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||||
sine_wavs = sine_wavs.transpose(1, 2)
|
|
||||||
uv = uv.transpose(1, 2)
|
|
||||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||||
|
|
||||||
# source for noise branch, in the same shape as uv
|
# source for noise branch, in the same shape as uv
|
||||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
if self.training is False and self.causal is True:
|
||||||
|
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
|
||||||
|
else:
|
||||||
|
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||||
return sine_merge, noise, uv
|
return sine_merge, noise, uv
|
||||||
|
|
||||||
|
|
||||||
@@ -231,13 +389,13 @@ class HiFTGenerator(nn.Module):
|
|||||||
nsf_alpha: float = 0.1,
|
nsf_alpha: float = 0.1,
|
||||||
nsf_sigma: float = 0.003,
|
nsf_sigma: float = 0.003,
|
||||||
nsf_voiced_threshold: float = 10,
|
nsf_voiced_threshold: float = 10,
|
||||||
upsample_rates: tp.List[int] = [8, 8],
|
upsample_rates: List[int] = [8, 8],
|
||||||
upsample_kernel_sizes: tp.List[int] = [16, 16],
|
upsample_kernel_sizes: List[int] = [16, 16],
|
||||||
istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||||
resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
|
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||||
resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
source_resblock_kernel_sizes: tp.List[int] = [7, 11],
|
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||||
source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
|
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||||
lrelu_slope: float = 0.1,
|
lrelu_slope: float = 0.1,
|
||||||
audio_limit: float = 0.99,
|
audio_limit: float = 0.99,
|
||||||
f0_predictor: torch.nn.Module = None,
|
f0_predictor: torch.nn.Module = None,
|
||||||
@@ -253,13 +411,16 @@ class HiFTGenerator(nn.Module):
|
|||||||
|
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
# NOTE in CosyVoice2, we use the original SineGen implementation
|
||||||
self.m_source = SourceModuleHnNSF(
|
self.m_source = SourceModuleHnNSF(
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||||
harmonic_num=nb_harmonics,
|
harmonic_num=nb_harmonics,
|
||||||
sine_amp=nsf_alpha,
|
sine_amp=nsf_alpha,
|
||||||
add_noise_std=nsf_sigma,
|
add_noise_std=nsf_sigma,
|
||||||
voiced_threshod=nsf_voiced_threshold)
|
voiced_threshod=nsf_voiced_threshold,
|
||||||
|
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||||
|
causal=False)
|
||||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||||
|
|
||||||
self.conv_pre = weight_norm(
|
self.conv_pre = weight_norm(
|
||||||
@@ -286,8 +447,7 @@ class HiFTGenerator(nn.Module):
|
|||||||
self.source_resblocks = nn.ModuleList()
|
self.source_resblocks = nn.ModuleList()
|
||||||
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
||||||
downsample_cum_rates = np.cumprod(downsample_rates)
|
downsample_cum_rates = np.cumprod(downsample_rates)
|
||||||
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
|
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
||||||
source_resblock_dilation_sizes)):
|
|
||||||
if u == 1:
|
if u == 1:
|
||||||
self.source_downs.append(
|
self.source_downs.append(
|
||||||
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
||||||
@@ -304,7 +464,7 @@ class HiFTGenerator(nn.Module):
|
|||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = base_channels // (2**(i + 1))
|
ch = base_channels // (2**(i + 1))
|
||||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
self.resblocks.append(ResBlock(ch, k, d))
|
self.resblocks.append(ResBlock(ch, k, d))
|
||||||
|
|
||||||
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
||||||
@@ -314,11 +474,19 @@ class HiFTGenerator(nn.Module):
|
|||||||
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||||
self.f0_predictor = f0_predictor
|
self.f0_predictor = f0_predictor
|
||||||
|
|
||||||
def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
|
def remove_weight_norm(self):
|
||||||
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
print('Removing weight norm...')
|
||||||
|
for l in self.ups:
|
||||||
har_source, _, _ = self.m_source(f0)
|
remove_weight_norm(l)
|
||||||
return har_source.transpose(1, 2)
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_weight_norm(self.conv_pre)
|
||||||
|
remove_weight_norm(self.conv_post)
|
||||||
|
self.m_source.remove_weight_norm()
|
||||||
|
for l in self.source_downs:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.source_resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
|
||||||
def _stft(self, x):
|
def _stft(self, x):
|
||||||
spec = torch.stft(
|
spec = torch.stft(
|
||||||
@@ -332,13 +500,11 @@ class HiFTGenerator(nn.Module):
|
|||||||
magnitude = torch.clip(magnitude, max=1e2)
|
magnitude = torch.clip(magnitude, max=1e2)
|
||||||
real = magnitude * torch.cos(phase)
|
real = magnitude * torch.cos(phase)
|
||||||
img = magnitude * torch.sin(phase)
|
img = magnitude * torch.sin(phase)
|
||||||
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
||||||
|
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
||||||
return inverse_transform
|
return inverse_transform
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||||
f0 = self.f0_predictor(x)
|
|
||||||
s = self._f02source(f0)
|
|
||||||
|
|
||||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||||
|
|
||||||
@@ -372,20 +538,209 @@ class HiFTGenerator(nn.Module):
|
|||||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def forward(
|
||||||
print('Removing weight norm...')
|
self,
|
||||||
for l in self.ups:
|
batch: dict,
|
||||||
remove_weight_norm(l)
|
device: torch.device,
|
||||||
for l in self.resblocks:
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
l.remove_weight_norm()
|
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
||||||
remove_weight_norm(self.conv_pre)
|
# mel->f0
|
||||||
remove_weight_norm(self.conv_post)
|
f0 = self.f0_predictor(speech_feat)
|
||||||
self.source_module.remove_weight_norm()
|
# f0->source
|
||||||
for l in self.source_downs:
|
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
remove_weight_norm(l)
|
s, _, _ = self.m_source(s)
|
||||||
for l in self.source_resblocks:
|
s = s.transpose(1, 2)
|
||||||
l.remove_weight_norm()
|
# mel+source->speech
|
||||||
|
generated_speech = self.decode(x=speech_feat, s=s)
|
||||||
|
return generated_speech, f0
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference(self, mel: torch.Tensor) -> torch.Tensor:
|
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||||
return self.forward(x=mel)
|
# mel->f0
|
||||||
|
f0 = self.f0_predictor(speech_feat)
|
||||||
|
# f0->source
|
||||||
|
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
|
s, _, _ = self.m_source(s)
|
||||||
|
s = s.transpose(1, 2)
|
||||||
|
# use cache_source to avoid glitch
|
||||||
|
if cache_source.shape[2] != 0:
|
||||||
|
s[:, :, :cache_source.shape[2]] = cache_source
|
||||||
|
generated_speech = self.decode(x=speech_feat, s=s)
|
||||||
|
return generated_speech, s
|
||||||
|
|
||||||
|
|
||||||
|
class CausalHiFTGenerator(HiFTGenerator):
|
||||||
|
"""
|
||||||
|
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||||
|
https://arxiv.org/abs/2309.09493
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 80,
|
||||||
|
base_channels: int = 512,
|
||||||
|
nb_harmonics: int = 8,
|
||||||
|
sampling_rate: int = 22050,
|
||||||
|
nsf_alpha: float = 0.1,
|
||||||
|
nsf_sigma: float = 0.003,
|
||||||
|
nsf_voiced_threshold: float = 10,
|
||||||
|
upsample_rates: List[int] = [8, 8],
|
||||||
|
upsample_kernel_sizes: List[int] = [16, 16],
|
||||||
|
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||||
|
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||||
|
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||||
|
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||||
|
lrelu_slope: float = 0.1,
|
||||||
|
audio_limit: float = 0.99,
|
||||||
|
conv_pre_look_right: int = 4,
|
||||||
|
f0_predictor: torch.nn.Module = None,
|
||||||
|
):
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
|
||||||
|
self.out_channels = 1
|
||||||
|
self.nb_harmonics = nb_harmonics
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.istft_params = istft_params
|
||||||
|
self.lrelu_slope = lrelu_slope
|
||||||
|
self.audio_limit = audio_limit
|
||||||
|
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
self.m_source = SourceModuleHnNSF(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||||
|
harmonic_num=nb_harmonics,
|
||||||
|
sine_amp=nsf_alpha,
|
||||||
|
add_noise_std=nsf_sigma,
|
||||||
|
voiced_threshod=nsf_voiced_threshold,
|
||||||
|
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||||
|
causal=True)
|
||||||
|
self.upsample_rates = upsample_rates
|
||||||
|
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||||
|
|
||||||
|
self.conv_pre = weight_norm(
|
||||||
|
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Up
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
weight_norm(
|
||||||
|
CausalConv1dUpsample(
|
||||||
|
base_channels // (2**i),
|
||||||
|
base_channels // (2**(i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Down
|
||||||
|
self.source_downs = nn.ModuleList()
|
||||||
|
self.source_resblocks = nn.ModuleList()
|
||||||
|
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
||||||
|
downsample_cum_rates = np.cumprod(downsample_rates)
|
||||||
|
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
||||||
|
if u == 1:
|
||||||
|
self.source_downs.append(
|
||||||
|
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.source_downs.append(
|
||||||
|
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.source_resblocks.append(
|
||||||
|
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = base_channels // (2**(i + 1))
|
||||||
|
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(ResBlock(ch, k, d, causal=True))
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
|
||||||
|
self.ups.apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||||
|
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||||
|
self.conv_pre_look_right = conv_pre_look_right
|
||||||
|
self.f0_predictor = f0_predictor
|
||||||
|
|
||||||
|
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
|
||||||
|
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||||
|
if finalize is True:
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
else:
|
||||||
|
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
|
||||||
|
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||||
|
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||||
|
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, self.lrelu_slope)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
|
||||||
|
if i == self.num_upsamples - 1:
|
||||||
|
x = self.reflection_pad(x)
|
||||||
|
|
||||||
|
# fusion
|
||||||
|
si = self.source_downs[i](s_stft)
|
||||||
|
si = self.source_resblocks[i](si)
|
||||||
|
x = x + si
|
||||||
|
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
||||||
|
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
||||||
|
|
||||||
|
x = self._istft(magnitude, phase)
|
||||||
|
if finalize is False:
|
||||||
|
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
|
||||||
|
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||||
|
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
||||||
|
self.f0_predictor.to(torch.float64)
|
||||||
|
f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
|
||||||
|
# f0->source
|
||||||
|
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
|
s, _, _ = self.m_source(s)
|
||||||
|
s = s.transpose(1, 2)
|
||||||
|
if finalize is True:
|
||||||
|
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
|
||||||
|
else:
|
||||||
|
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
|
||||||
|
return generated_speech, s
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
from hyperpyyaml import load_hyperpyyaml
|
||||||
|
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||||
|
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
|
||||||
|
model = configs['hift']
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
max_len, chunk_size, context_size = 300, 30, 8
|
||||||
|
mel = torch.rand(1, 80, max_len).to(device)
|
||||||
|
pred_gt, _ = model.inference(mel)
|
||||||
|
for i in range(0, max_len, chunk_size):
|
||||||
|
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||||
|
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
|
||||||
|
pred_chunk = pred_chunk[:, i * 480:]
|
||||||
|
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())
|
||||||
|
|||||||
67
cosyvoice/hifigan/hifigan.py
Normal file
67
cosyvoice/hifigan/hifigan.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
||||||
|
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
||||||
|
|
||||||
|
|
||||||
|
class HiFiGan(nn.Module):
|
||||||
|
def __init__(self, generator, discriminator, mel_spec_transform,
|
||||||
|
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
||||||
|
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
|
||||||
|
super(HiFiGan, self).__init__()
|
||||||
|
self.generator = generator
|
||||||
|
self.discriminator = discriminator
|
||||||
|
self.mel_spec_transform = mel_spec_transform
|
||||||
|
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
|
||||||
|
self.feat_match_loss_weight = feat_match_loss_weight
|
||||||
|
self.tpr_loss_weight = tpr_loss_weight
|
||||||
|
self.tpr_loss_tau = tpr_loss_tau
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
if batch['turn'] == 'generator':
|
||||||
|
return self.forward_generator(batch, device)
|
||||||
|
else:
|
||||||
|
return self.forward_discriminator(batch, device)
|
||||||
|
|
||||||
|
def forward_generator(self, batch, device):
|
||||||
|
real_speech = batch['speech'].to(device)
|
||||||
|
pitch_feat = batch['pitch_feat'].to(device)
|
||||||
|
# 1. calculate generator outputs
|
||||||
|
generated_speech, generated_f0 = self.generator(batch, device)
|
||||||
|
# 2. calculate discriminator outputs
|
||||||
|
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
||||||
|
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
|
||||||
|
loss_gen, _ = generator_loss(y_d_gs)
|
||||||
|
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
||||||
|
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
||||||
|
if self.tpr_loss_weight != 0:
|
||||||
|
loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
|
||||||
|
else:
|
||||||
|
loss_tpr = torch.zeros(1).to(device)
|
||||||
|
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||||
|
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
||||||
|
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
||||||
|
self.tpr_loss_weight * loss_tpr + loss_f0
|
||||||
|
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||||
|
|
||||||
|
def forward_discriminator(self, batch, device):
|
||||||
|
real_speech = batch['speech'].to(device)
|
||||||
|
# 1. calculate generator outputs
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_speech, generated_f0 = self.generator(batch, device)
|
||||||
|
# 2. calculate discriminator outputs
|
||||||
|
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
|
||||||
|
# 3. calculate discriminator losses, tpr losses [Optional]
|
||||||
|
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
||||||
|
if self.tpr_loss_weight != 0:
|
||||||
|
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
||||||
|
else:
|
||||||
|
loss_tpr = torch.zeros(1).to(device)
|
||||||
|
loss = loss_disc + self.tpr_loss_weight * loss_tpr
|
||||||
|
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||||
|
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -11,14 +12,23 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, Optional, Union
|
import os, queue
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from typing import Dict, Optional, Callable, List, Generator
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from transformers import Qwen2ForCausalLM
|
||||||
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
||||||
from cosyvoice.utils.common import IGNORE_ID
|
from cosyvoice.utils.common import IGNORE_ID
|
||||||
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
||||||
from cosyvoice.utils.common import th_accuracy
|
from cosyvoice.utils.common import th_accuracy
|
||||||
|
from cosyvoice.utils.file_utils import logging
|
||||||
|
from cosyvoice.utils.mask import make_pad_mask
|
||||||
|
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||||
|
|
||||||
|
|
||||||
class TransformerLM(torch.nn.Module):
|
class TransformerLM(torch.nn.Module):
|
||||||
@@ -31,6 +41,7 @@ class TransformerLM(torch.nn.Module):
|
|||||||
speech_token_size: int,
|
speech_token_size: int,
|
||||||
text_encoder: torch.nn.Module,
|
text_encoder: torch.nn.Module,
|
||||||
llm: torch.nn.Module,
|
llm: torch.nn.Module,
|
||||||
|
sampling: Callable,
|
||||||
length_normalized_loss: bool = True,
|
length_normalized_loss: bool = True,
|
||||||
lsm_weight: float = 0.0,
|
lsm_weight: float = 0.0,
|
||||||
spk_embed_dim: int = 192,
|
spk_embed_dim: int = 192,
|
||||||
@@ -47,8 +58,9 @@ class TransformerLM(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. build speech token language model related modules
|
# 2. build speech token language model related modules
|
||||||
self.sos_eos = 0
|
self.sos = 0
|
||||||
self.task_id = 1
|
self.task_id = 1
|
||||||
|
self.eos_token = self.speech_token_size
|
||||||
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
||||||
@@ -63,6 +75,9 @@ class TransformerLM(torch.nn.Module):
|
|||||||
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
||||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
||||||
|
|
||||||
|
# 4. sampling method
|
||||||
|
self.sampling = sampling
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
text: torch.Tensor,
|
text: torch.Tensor,
|
||||||
@@ -73,10 +88,11 @@ class TransformerLM(torch.nn.Module):
|
|||||||
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
||||||
return encoder_out, encoder_out_lens
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
||||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||||
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
|
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
||||||
|
for i in range(len(text_token))]
|
||||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||||
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
||||||
return lm_input, lm_input_len
|
return lm_input, lm_input_len
|
||||||
@@ -100,7 +116,8 @@ class TransformerLM(torch.nn.Module):
|
|||||||
embedding = batch['embedding'].to(device)
|
embedding = batch['embedding'].to(device)
|
||||||
|
|
||||||
# 1. prepare llm_target
|
# 1. prepare llm_target
|
||||||
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
|
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
|
||||||
|
[self.speech_token_size]) for i in range(text_token.size(0))]
|
||||||
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
||||||
|
|
||||||
# 1. encode text_token
|
# 1. encode text_token
|
||||||
@@ -112,15 +129,16 @@ class TransformerLM(torch.nn.Module):
|
|||||||
embedding = self.spk_embed_affine_layer(embedding)
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
embedding = embedding.unsqueeze(1)
|
embedding = embedding.unsqueeze(1)
|
||||||
|
|
||||||
# 3. eos and task_id
|
# 3. sos and task_id
|
||||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
|
||||||
# 4. encode speech_token
|
# 4. encode speech_token
|
||||||
speech_token = self.speech_embedding(speech_token)
|
speech_token = self.speech_embedding(speech_token)
|
||||||
|
|
||||||
# 5. unpad and pad
|
# 5. unpad and pad
|
||||||
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
|
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
|
||||||
|
task_id_emb, speech_token, speech_token_len)
|
||||||
|
|
||||||
# 6. run lm forward
|
# 6. run lm forward
|
||||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||||
@@ -132,16 +150,18 @@ class TransformerLM(torch.nn.Module):
|
|||||||
def sampling_ids(
|
def sampling_ids(
|
||||||
self,
|
self,
|
||||||
weighted_scores: torch.Tensor,
|
weighted_scores: torch.Tensor,
|
||||||
sampling: Union[bool, int, float] = True,
|
decoded_tokens: List,
|
||||||
beam_size: int = 1,
|
sampling: int,
|
||||||
ignore_eos: bool = True,
|
ignore_eos: bool = True,
|
||||||
):
|
):
|
||||||
|
num_trials, max_trials = 0, 100
|
||||||
while True:
|
while True:
|
||||||
prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
|
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
top_ids = prob.multinomial(beam_size, replacement=True)
|
if (not ignore_eos) or (top_ids < self.speech_token_size):
|
||||||
top_ids = indices[top_ids]
|
|
||||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
|
||||||
break
|
break
|
||||||
|
num_trials += 1
|
||||||
|
if num_trials > max_trials:
|
||||||
|
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
||||||
return top_ids
|
return top_ids
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@@ -154,11 +174,11 @@ class TransformerLM(torch.nn.Module):
|
|||||||
prompt_speech_token: torch.Tensor,
|
prompt_speech_token: torch.Tensor,
|
||||||
prompt_speech_token_len: torch.Tensor,
|
prompt_speech_token_len: torch.Tensor,
|
||||||
embedding: torch.Tensor,
|
embedding: torch.Tensor,
|
||||||
beam_size: int = 1,
|
|
||||||
sampling: int = 25,
|
sampling: int = 25,
|
||||||
max_token_text_ratio: float = 20,
|
max_token_text_ratio: float = 20,
|
||||||
min_token_text_ratio: float = 2,
|
min_token_text_ratio: float = 2,
|
||||||
) -> torch.Tensor:
|
uuid: str = '',
|
||||||
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
device = text.device
|
device = text.device
|
||||||
text = torch.concat([prompt_text, text], dim=1)
|
text = torch.concat([prompt_text, text], dim=1)
|
||||||
text_len += prompt_text_len
|
text_len += prompt_text_len
|
||||||
@@ -173,16 +193,16 @@ class TransformerLM(torch.nn.Module):
|
|||||||
embedding = self.spk_embed_affine_layer(embedding)
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
embedding = embedding.unsqueeze(dim=1)
|
embedding = embedding.unsqueeze(dim=1)
|
||||||
else:
|
else:
|
||||||
embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
|
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||||
|
|
||||||
# 3. concat llm_input
|
# 3. concat llm_input
|
||||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
if prompt_speech_token_len != 0:
|
if prompt_speech_token_len != 0:
|
||||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
else:
|
else:
|
||||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||||
|
|
||||||
# 4. cal min/max_length
|
# 4. cal min/max_length
|
||||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||||
@@ -193,14 +213,548 @@ class TransformerLM(torch.nn.Module):
|
|||||||
offset = 0
|
offset = 0
|
||||||
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
||||||
for i in range(max_len):
|
for i in range(max_len):
|
||||||
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
|
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
||||||
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
|
att_cache=att_cache, cnn_cache=cnn_cache,
|
||||||
|
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
||||||
|
device=lm_input.device)).to(torch.bool))
|
||||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||||
if top_ids == self.speech_token_size:
|
if top_ids == self.eos_token:
|
||||||
break
|
break
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield top_ids
|
||||||
out_tokens.append(top_ids)
|
out_tokens.append(top_ids)
|
||||||
offset += lm_input.size(1)
|
offset += lm_input.size(1)
|
||||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
|
||||||
|
class Qwen2Encoder(torch.nn.Module):
|
||||||
|
def __init__(self, pretrain_path):
|
||||||
|
super().__init__()
|
||||||
|
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
||||||
|
|
||||||
|
def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
|
||||||
|
T = xs.size(1)
|
||||||
|
masks = ~make_pad_mask(xs_lens, T)
|
||||||
|
outs = self.model(
|
||||||
|
inputs_embeds=xs,
|
||||||
|
attention_mask=masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
return outs.hidden_states[-1], masks.unsqueeze(1)
|
||||||
|
|
||||||
|
def forward_one_step(self, xs, masks, cache=None):
|
||||||
|
input_masks = masks[:, -1, :]
|
||||||
|
outs = self.model(
|
||||||
|
inputs_embeds=xs,
|
||||||
|
attention_mask=input_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=True,
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=cache,
|
||||||
|
)
|
||||||
|
xs = outs.hidden_states[-1]
|
||||||
|
new_cache = outs.past_key_values
|
||||||
|
return xs, new_cache
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2LM(TransformerLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_input_size: int,
|
||||||
|
llm_output_size: int,
|
||||||
|
speech_token_size: int,
|
||||||
|
llm: torch.nn.Module,
|
||||||
|
sampling: Callable,
|
||||||
|
length_normalized_loss: bool = True,
|
||||||
|
lsm_weight: float = 0.0,
|
||||||
|
mix_ratio: List[int] = [5, 15],
|
||||||
|
):
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
self.llm_input_size = llm_input_size
|
||||||
|
self.llm_output_size = llm_output_size
|
||||||
|
self.speech_token_size = speech_token_size
|
||||||
|
# 2. build speech token language model related modules
|
||||||
|
self.sos = 0
|
||||||
|
self.task_id = 1
|
||||||
|
self.eos_token = speech_token_size
|
||||||
|
self.fill_token = speech_token_size + 2
|
||||||
|
|
||||||
|
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||||
|
self.llm = llm
|
||||||
|
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
||||||
|
self.criterion_ce = LabelSmoothingLoss(
|
||||||
|
size=speech_token_size + 3,
|
||||||
|
padding_idx=IGNORE_ID,
|
||||||
|
smoothing=lsm_weight,
|
||||||
|
normalize_length=length_normalized_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. [Optional] build speech token related modules
|
||||||
|
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
||||||
|
|
||||||
|
# 4. sampling method
|
||||||
|
self.sampling = sampling
|
||||||
|
self.mix_ratio = mix_ratio
|
||||||
|
|
||||||
|
# 5. vllm related
|
||||||
|
self.stop_token_ids = [speech_token_size + i for i in range(3)]
|
||||||
|
self.vllm_output_queue = {}
|
||||||
|
if online_feature is True:
|
||||||
|
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||||
|
|
||||||
|
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
|
||||||
|
lm_target, lm_input = [], []
|
||||||
|
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||||
|
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||||
|
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
|
||||||
|
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
|
||||||
|
# NOTE add instruct_token in CosyVoice3
|
||||||
|
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
||||||
|
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
|
||||||
|
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
|
||||||
|
else:
|
||||||
|
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
|
||||||
|
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
|
||||||
|
instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
|
||||||
|
for i in range(len(text_token)):
|
||||||
|
# bistream sequence
|
||||||
|
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
|
||||||
|
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
|
||||||
|
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
|
||||||
|
this_lm_input.append(instruct_token_emb[i])
|
||||||
|
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
|
||||||
|
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
|
||||||
|
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
|
||||||
|
if len(this_text_token) == self.mix_ratio[0]:
|
||||||
|
assert len(this_speech_token) == self.mix_ratio[1]
|
||||||
|
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
|
||||||
|
this_lm_target += this_speech_token
|
||||||
|
this_lm_target.append(self.fill_token)
|
||||||
|
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
|
||||||
|
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
|
||||||
|
else:
|
||||||
|
this_lm_target += [-1] * len(this_text_token)
|
||||||
|
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
|
||||||
|
this_lm_target.append(self.eos_token)
|
||||||
|
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
|
||||||
|
this_lm_input.append(task_id_emb.squeeze(dim=0))
|
||||||
|
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
|
||||||
|
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
|
||||||
|
# unistream sequence
|
||||||
|
else:
|
||||||
|
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
|
||||||
|
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
|
||||||
|
lm_target.append(this_lm_target)
|
||||||
|
lm_input.append(this_lm_input)
|
||||||
|
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||||
|
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
||||||
|
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
|
||||||
|
return lm_target, lm_input, lm_input_len
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
text: (B, L, D)
|
||||||
|
text_lengths: (B,)
|
||||||
|
audio: (B, T, N) or (B, T)
|
||||||
|
audio_lengths: (B,)
|
||||||
|
"""
|
||||||
|
text_token = batch['text_token'].to(device)
|
||||||
|
text_token_len = batch['text_token_len'].to(device)
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
|
speech_token = batch['speech_token'].to(device)
|
||||||
|
speech_token_len = batch['speech_token_len'].to(device)
|
||||||
|
|
||||||
|
# 1. encode text_token
|
||||||
|
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||||
|
|
||||||
|
# 3. sos and task_id
|
||||||
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
# 2. encode speech_token
|
||||||
|
speech_token_emb = self.speech_embedding(speech_token)
|
||||||
|
|
||||||
|
# 3. prepare llm_input/target
|
||||||
|
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||||
|
speech_token, speech_token_emb, speech_token_len)
|
||||||
|
lm_target = lm_target.to(device)
|
||||||
|
|
||||||
|
# 4. run lm forward
|
||||||
|
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||||
|
logits = self.llm_decoder(lm_output)
|
||||||
|
loss = self.criterion_ce(logits, lm_target.to(device))
|
||||||
|
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
|
||||||
|
return {'loss': loss, 'acc': acc}
|
||||||
|
|
||||||
|
def forward_dpo(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
text_token = batch['text_token'].to(device)
|
||||||
|
text_token_len = batch['text_token_len'].to(device)
|
||||||
|
speech_token = batch['speech_token'].to(device)
|
||||||
|
speech_token_len = batch['speech_token_len'].to(device)
|
||||||
|
reject_speech_token = batch['reject_speech_token'].to(device)
|
||||||
|
reject_speech_token_len = batch['reject_speech_token_len'].to(device)
|
||||||
|
|
||||||
|
# 1. encode text_token
|
||||||
|
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||||
|
|
||||||
|
# 3. sos and task_id
|
||||||
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
# 2. encode speech_token
|
||||||
|
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||||
|
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
|
||||||
|
speech_token_combined = speech_token + reject_speech_token
|
||||||
|
speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
|
||||||
|
speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
|
||||||
|
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
|
||||||
|
|
||||||
|
# 3. prepare llm_input/target
|
||||||
|
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
|
||||||
|
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
|
||||||
|
lm_target = lm_target.to(device)
|
||||||
|
|
||||||
|
# 4. run lm forward
|
||||||
|
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||||
|
logits = self.llm_decoder(lm_output)
|
||||||
|
chosen_logits = logits[:text_token.shape[0]]
|
||||||
|
rejected_logits = logits[text_token.shape[0]:]
|
||||||
|
chosen_lm_target = lm_target[:text_token.shape[0]]
|
||||||
|
rejected_lm_target = lm_target[text_token.shape[0]:]
|
||||||
|
loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
|
||||||
|
acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
|
||||||
|
|
||||||
|
# 5. calculate dpo logits
|
||||||
|
chosen_lm_mask = chosen_lm_target == IGNORE_ID
|
||||||
|
rejected_lm_mask = rejected_lm_target == IGNORE_ID
|
||||||
|
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||||
|
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||||
|
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
|
||||||
|
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
|
||||||
|
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_len: torch.Tensor,
|
||||||
|
prompt_text: torch.Tensor,
|
||||||
|
prompt_text_len: torch.Tensor,
|
||||||
|
prompt_speech_token: torch.Tensor,
|
||||||
|
prompt_speech_token_len: torch.Tensor,
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
sampling: int = 25,
|
||||||
|
max_token_text_ratio: float = 20,
|
||||||
|
min_token_text_ratio: float = 2,
|
||||||
|
uuid: str = '',
|
||||||
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
|
device = text.device
|
||||||
|
text = torch.concat([prompt_text, text], dim=1)
|
||||||
|
text_len += prompt_text_len
|
||||||
|
text = self.llm.model.model.embed_tokens(text)
|
||||||
|
|
||||||
|
# 3. concat llm_input
|
||||||
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
if prompt_speech_token_len != 0:
|
||||||
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
|
else:
|
||||||
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
|
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||||
|
|
||||||
|
# 4. cal min/max_length
|
||||||
|
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||||
|
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||||
|
|
||||||
|
# 5. step by step decode
|
||||||
|
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||||
|
yield token
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
|
||||||
|
if hasattr(self, 'vllm'):
|
||||||
|
from vllm import SamplingParams, RequestOutput
|
||||||
|
sampling_params = SamplingParams(top_k=sampling,
|
||||||
|
stop_token_ids=self.stop_token_ids,
|
||||||
|
min_tokens=min_len,
|
||||||
|
max_tokens=max_len)
|
||||||
|
with self.lock:
|
||||||
|
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
|
||||||
|
self.vllm_output_queue[uuid] = queue.Queue()
|
||||||
|
out_tokens = []
|
||||||
|
while True:
|
||||||
|
with self.lock:
|
||||||
|
if self.vllm_output_queue[uuid].empty() is True:
|
||||||
|
request_outputs: List[RequestOutput] = self.vllm.step()
|
||||||
|
for request_output in request_outputs:
|
||||||
|
top_ids = list(request_output.outputs[0].token_ids)[-1]
|
||||||
|
self.vllm_output_queue[request_output.request_id].put(top_ids)
|
||||||
|
if self.vllm_output_queue[uuid].empty() is False:
|
||||||
|
top_ids = self.vllm_output_queue[uuid].get()
|
||||||
|
if top_ids in self.stop_token_ids:
|
||||||
|
break
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield top_ids
|
||||||
|
out_tokens.append(top_ids)
|
||||||
|
if len(out_tokens) == max_len:
|
||||||
|
break
|
||||||
|
time.sleep(0.001)
|
||||||
|
with self.lock:
|
||||||
|
self.vllm_output_queue.pop(uuid)
|
||||||
|
else:
|
||||||
|
out_tokens = []
|
||||||
|
cache = None
|
||||||
|
for i in range(max_len):
|
||||||
|
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||||
|
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||||
|
cache=cache)
|
||||||
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||||
|
if top_ids in self.stop_token_ids:
|
||||||
|
break
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield top_ids
|
||||||
|
out_tokens.append(top_ids)
|
||||||
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference_bistream(
|
||||||
|
self,
|
||||||
|
text: Generator,
|
||||||
|
prompt_text: torch.Tensor,
|
||||||
|
prompt_text_len: torch.Tensor,
|
||||||
|
prompt_speech_token: torch.Tensor,
|
||||||
|
prompt_speech_token_len: torch.Tensor,
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
sampling: int = 25,
|
||||||
|
max_token_text_ratio: float = 20,
|
||||||
|
min_token_text_ratio: float = 2,
|
||||||
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
|
|
||||||
|
device = prompt_text.device
|
||||||
|
# 1. prepare input
|
||||||
|
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
if prompt_speech_token_len != 0:
|
||||||
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
|
else:
|
||||||
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
|
||||||
|
lm_input = torch.concat([sos_emb], dim=1)
|
||||||
|
|
||||||
|
# 2. iterate text
|
||||||
|
out_tokens = []
|
||||||
|
cache = None
|
||||||
|
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
|
||||||
|
text_cache = self.llm.model.model.embed_tokens(prompt_text)
|
||||||
|
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
|
||||||
|
for this_text in text:
|
||||||
|
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
|
||||||
|
# prompt_speech_token_emb not empty, try append to lm_input
|
||||||
|
while prompt_speech_token_emb.size(1) != 0:
|
||||||
|
if text_cache.size(1) >= self.mix_ratio[0]:
|
||||||
|
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
|
||||||
|
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
|
||||||
|
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
|
||||||
|
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
|
||||||
|
else:
|
||||||
|
logging.info('not enough text token to decode, wait for more')
|
||||||
|
break
|
||||||
|
# no prompt_speech_token_emb remain, can decode some speech token
|
||||||
|
if prompt_speech_token_emb.size(1) == 0:
|
||||||
|
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
||||||
|
logging.info('get fill token, need to append more text token')
|
||||||
|
if text_cache.size(1) >= self.mix_ratio[0]:
|
||||||
|
lm_input_text = text_cache[:, :self.mix_ratio[0]]
|
||||||
|
logging.info('append {} text token'.format(lm_input_text.size(1)))
|
||||||
|
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
|
||||||
|
lm_input = lm_input_text
|
||||||
|
else:
|
||||||
|
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
|
||||||
|
text_cache = text_cache[:, self.mix_ratio[0]:]
|
||||||
|
else:
|
||||||
|
logging.info('not enough text token to decode, wait for more')
|
||||||
|
continue
|
||||||
|
while True:
|
||||||
|
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
||||||
|
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||||
|
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
||||||
|
cache=cache)
|
||||||
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
|
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
|
||||||
|
top_ids = self.fill_token
|
||||||
|
next_fill_index += (self.mix_ratio[1] + 1)
|
||||||
|
else:
|
||||||
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
|
||||||
|
if top_ids == self.fill_token:
|
||||||
|
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
|
||||||
|
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
|
||||||
|
out_tokens.append(top_ids)
|
||||||
|
if top_ids >= self.speech_token_size:
|
||||||
|
if top_ids == self.fill_token:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError('should not get token {}'.format(top_ids))
|
||||||
|
yield top_ids
|
||||||
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
# 3. final decode
|
||||||
|
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
|
||||||
|
logging.info('no more text token, decode until met eos')
|
||||||
|
while True:
|
||||||
|
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
||||||
|
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||||
|
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
||||||
|
cache=cache)
|
||||||
|
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||||
|
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
|
||||||
|
out_tokens.append(top_ids)
|
||||||
|
if top_ids >= self.speech_token_size:
|
||||||
|
if top_ids == self.eos_token:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError('should not get token {}'.format(top_ids))
|
||||||
|
# in stream mode, yield token one by one
|
||||||
|
yield top_ids
|
||||||
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice3LM(Qwen2LM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_input_size: int,
|
||||||
|
llm_output_size: int,
|
||||||
|
speech_token_size: int,
|
||||||
|
llm: torch.nn.Module,
|
||||||
|
sampling: Callable,
|
||||||
|
length_normalized_loss: bool = True,
|
||||||
|
lsm_weight: float = 0.0,
|
||||||
|
mix_ratio: List[int] = [5, 15],
|
||||||
|
):
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
self.llm_input_size = llm_input_size
|
||||||
|
self.llm_output_size = llm_output_size
|
||||||
|
self.speech_token_size = speech_token_size
|
||||||
|
# 2. build speech token language model related modules
|
||||||
|
self.sos = speech_token_size + 0
|
||||||
|
self.eos_token = speech_token_size + 1
|
||||||
|
self.task_id = speech_token_size + 2
|
||||||
|
self.fill_token = speech_token_size + 3
|
||||||
|
|
||||||
|
self.llm = llm
|
||||||
|
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
|
||||||
|
self.criterion_ce = LabelSmoothingLoss(
|
||||||
|
size=speech_token_size + 200,
|
||||||
|
padding_idx=IGNORE_ID,
|
||||||
|
smoothing=lsm_weight,
|
||||||
|
normalize_length=length_normalized_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. [Optional] build speech token related modules
|
||||||
|
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
|
||||||
|
|
||||||
|
# 4. sampling method
|
||||||
|
self.sampling = sampling
|
||||||
|
self.mix_ratio = mix_ratio
|
||||||
|
|
||||||
|
# 5. vllm related
|
||||||
|
self.stop_token_ids = [speech_token_size + i for i in range(200)]
|
||||||
|
self.vllm_output_queue = {}
|
||||||
|
if online_feature is True:
|
||||||
|
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
text: (B, L, D)
|
||||||
|
text_lengths: (B,)
|
||||||
|
audio: (B, T, N) or (B, T)
|
||||||
|
audio_lengths: (B,)
|
||||||
|
"""
|
||||||
|
text_token = batch['text_token'].to(device)
|
||||||
|
text_token_len = batch['text_token_len'].to(device)
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
|
speech_token = batch['speech_token'].to(device)
|
||||||
|
speech_token_len = batch['speech_token_len'].to(device)
|
||||||
|
|
||||||
|
# NOTE should append instruct_token to sequence, not implemented yet
|
||||||
|
instruct_token = batch['instruct_token'].to(device)
|
||||||
|
instruct_token_len = batch['instruct_token_len'].to(device)
|
||||||
|
|
||||||
|
# 1. encode text_token
|
||||||
|
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||||
|
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
|
||||||
|
|
||||||
|
# 3. sos and task_id
|
||||||
|
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
|
||||||
|
# 2. encode speech_token
|
||||||
|
speech_token_emb = self.speech_embedding(speech_token)
|
||||||
|
|
||||||
|
# 3. prepare llm_input/target
|
||||||
|
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||||
|
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
|
||||||
|
lm_target = lm_target.to(device)
|
||||||
|
|
||||||
|
# 4. run lm forward
|
||||||
|
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||||
|
logits = self.llm_decoder(lm_output)
|
||||||
|
loss = self.criterion_ce(logits, lm_target.to(device))
|
||||||
|
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
|
||||||
|
return {'loss': loss, 'acc': acc}
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_len: torch.Tensor,
|
||||||
|
prompt_text: torch.Tensor,
|
||||||
|
prompt_text_len: torch.Tensor,
|
||||||
|
prompt_speech_token: torch.Tensor,
|
||||||
|
prompt_speech_token_len: torch.Tensor,
|
||||||
|
embedding: torch.Tensor,
|
||||||
|
sampling: int = 25,
|
||||||
|
max_token_text_ratio: float = 20,
|
||||||
|
min_token_text_ratio: float = 2,
|
||||||
|
uuid: str = '',
|
||||||
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
|
device = text.device
|
||||||
|
text = torch.concat([prompt_text, text], dim=1)
|
||||||
|
text_len += prompt_text_len
|
||||||
|
text = self.llm.model.model.embed_tokens(text)
|
||||||
|
|
||||||
|
# 3. concat llm_input
|
||||||
|
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||||
|
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||||
|
if prompt_speech_token_len != 0:
|
||||||
|
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||||
|
else:
|
||||||
|
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||||
|
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||||
|
|
||||||
|
# 4. cal min/max_length
|
||||||
|
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||||
|
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||||
|
|
||||||
|
# 5. step by step decode
|
||||||
|
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||||
|
yield token
|
||||||
|
|||||||
58836
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
Normal file
58836
cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken
Normal file
File diff suppressed because it is too large
Load Diff
327
cosyvoice/tokenizer/tokenizer.py
Normal file
327
cosyvoice/tokenizer/tokenizer.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from whisper.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
LANGUAGES = {
|
||||||
|
"en": "english",
|
||||||
|
"zh": "chinese",
|
||||||
|
"de": "german",
|
||||||
|
"es": "spanish",
|
||||||
|
"ru": "russian",
|
||||||
|
"ko": "korean",
|
||||||
|
"fr": "french",
|
||||||
|
"ja": "japanese",
|
||||||
|
"pt": "portuguese",
|
||||||
|
"tr": "turkish",
|
||||||
|
"pl": "polish",
|
||||||
|
"ca": "catalan",
|
||||||
|
"nl": "dutch",
|
||||||
|
"ar": "arabic",
|
||||||
|
"sv": "swedish",
|
||||||
|
"it": "italian",
|
||||||
|
"id": "indonesian",
|
||||||
|
"hi": "hindi",
|
||||||
|
"fi": "finnish",
|
||||||
|
"vi": "vietnamese",
|
||||||
|
"he": "hebrew",
|
||||||
|
"uk": "ukrainian",
|
||||||
|
"el": "greek",
|
||||||
|
"ms": "malay",
|
||||||
|
"cs": "czech",
|
||||||
|
"ro": "romanian",
|
||||||
|
"da": "danish",
|
||||||
|
"hu": "hungarian",
|
||||||
|
"ta": "tamil",
|
||||||
|
"no": "norwegian",
|
||||||
|
"th": "thai",
|
||||||
|
"ur": "urdu",
|
||||||
|
"hr": "croatian",
|
||||||
|
"bg": "bulgarian",
|
||||||
|
"lt": "lithuanian",
|
||||||
|
"la": "latin",
|
||||||
|
"mi": "maori",
|
||||||
|
"ml": "malayalam",
|
||||||
|
"cy": "welsh",
|
||||||
|
"sk": "slovak",
|
||||||
|
"te": "telugu",
|
||||||
|
"fa": "persian",
|
||||||
|
"lv": "latvian",
|
||||||
|
"bn": "bengali",
|
||||||
|
"sr": "serbian",
|
||||||
|
"az": "azerbaijani",
|
||||||
|
"sl": "slovenian",
|
||||||
|
"kn": "kannada",
|
||||||
|
"et": "estonian",
|
||||||
|
"mk": "macedonian",
|
||||||
|
"br": "breton",
|
||||||
|
"eu": "basque",
|
||||||
|
"is": "icelandic",
|
||||||
|
"hy": "armenian",
|
||||||
|
"ne": "nepali",
|
||||||
|
"mn": "mongolian",
|
||||||
|
"bs": "bosnian",
|
||||||
|
"kk": "kazakh",
|
||||||
|
"sq": "albanian",
|
||||||
|
"sw": "swahili",
|
||||||
|
"gl": "galician",
|
||||||
|
"mr": "marathi",
|
||||||
|
"pa": "punjabi",
|
||||||
|
"si": "sinhala",
|
||||||
|
"km": "khmer",
|
||||||
|
"sn": "shona",
|
||||||
|
"yo": "yoruba",
|
||||||
|
"so": "somali",
|
||||||
|
"af": "afrikaans",
|
||||||
|
"oc": "occitan",
|
||||||
|
"ka": "georgian",
|
||||||
|
"be": "belarusian",
|
||||||
|
"tg": "tajik",
|
||||||
|
"sd": "sindhi",
|
||||||
|
"gu": "gujarati",
|
||||||
|
"am": "amharic",
|
||||||
|
"yi": "yiddish",
|
||||||
|
"lo": "lao",
|
||||||
|
"uz": "uzbek",
|
||||||
|
"fo": "faroese",
|
||||||
|
"ht": "haitian creole",
|
||||||
|
"ps": "pashto",
|
||||||
|
"tk": "turkmen",
|
||||||
|
"nn": "nynorsk",
|
||||||
|
"mt": "maltese",
|
||||||
|
"sa": "sanskrit",
|
||||||
|
"lb": "luxembourgish",
|
||||||
|
"my": "myanmar",
|
||||||
|
"bo": "tibetan",
|
||||||
|
"tl": "tagalog",
|
||||||
|
"mg": "malagasy",
|
||||||
|
"as": "assamese",
|
||||||
|
"tt": "tatar",
|
||||||
|
"haw": "hawaiian",
|
||||||
|
"ln": "lingala",
|
||||||
|
"ha": "hausa",
|
||||||
|
"ba": "bashkir",
|
||||||
|
"jw": "javanese",
|
||||||
|
"su": "sundanese",
|
||||||
|
"yue": "cantonese",
|
||||||
|
"minnan": "minnan",
|
||||||
|
"wuyu": "wuyu",
|
||||||
|
"dialect": "dialect",
|
||||||
|
"zh/en": "zh/en",
|
||||||
|
"en/zh": "en/zh",
|
||||||
|
}
|
||||||
|
|
||||||
|
# language code lookup by name, with a few language aliases
|
||||||
|
TO_LANGUAGE_CODE = {
|
||||||
|
**{language: code for code, language in LANGUAGES.items()},
|
||||||
|
"burmese": "my",
|
||||||
|
"valencian": "ca",
|
||||||
|
"flemish": "nl",
|
||||||
|
"haitian": "ht",
|
||||||
|
"letzeburgesch": "lb",
|
||||||
|
"pushto": "ps",
|
||||||
|
"panjabi": "pa",
|
||||||
|
"moldavian": "ro",
|
||||||
|
"moldovan": "ro",
|
||||||
|
"sinhalese": "si",
|
||||||
|
"castilian": "es",
|
||||||
|
"mandarin": "zh",
|
||||||
|
}
|
||||||
|
|
||||||
|
AUDIO_EVENT = {
|
||||||
|
"ASR": "ASR",
|
||||||
|
"AED": "AED",
|
||||||
|
"SER": "SER",
|
||||||
|
"Speech": "Speech",
|
||||||
|
"/Speech": "/Speech",
|
||||||
|
"BGM": "BGM",
|
||||||
|
"/BGM": "/BGM",
|
||||||
|
"Laughter": "Laughter",
|
||||||
|
"/Laughter": "/Laughter",
|
||||||
|
"Applause": "Applause",
|
||||||
|
"/Applause": "/Applause",
|
||||||
|
}
|
||||||
|
|
||||||
|
EMOTION = {
|
||||||
|
"HAPPY": "HAPPY",
|
||||||
|
"SAD": "SAD",
|
||||||
|
"ANGRY": "ANGRY",
|
||||||
|
"NEUTRAL": "NEUTRAL",
|
||||||
|
}
|
||||||
|
|
||||||
|
TTS_Vocal_Token = {
|
||||||
|
"TTS/B": "TTS/B",
|
||||||
|
"TTS/O": "TTS/O",
|
||||||
|
"TTS/Q": "TTS/Q",
|
||||||
|
"TTS/A": "TTS/A",
|
||||||
|
"TTS/CO": "TTS/CO",
|
||||||
|
"TTS/CL": "TTS/CL",
|
||||||
|
"TTS/H": "TTS/H",
|
||||||
|
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||||
|
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||||
|
ranks = {
|
||||||
|
base64.b64decode(token): int(rank)
|
||||||
|
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||||
|
}
|
||||||
|
n_vocab = len(ranks)
|
||||||
|
special_tokens = {}
|
||||||
|
|
||||||
|
specials = [
|
||||||
|
"<|endoftext|>",
|
||||||
|
"<|startoftranscript|>",
|
||||||
|
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||||
|
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
||||||
|
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
||||||
|
"<|translate|>",
|
||||||
|
"<|transcribe|>",
|
||||||
|
"<|startoflm|>",
|
||||||
|
"<|startofprev|>",
|
||||||
|
"<|nospeech|>",
|
||||||
|
"<|notimestamps|>",
|
||||||
|
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
|
||||||
|
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
|
||||||
|
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||||
|
]
|
||||||
|
|
||||||
|
for token in specials:
|
||||||
|
special_tokens[token] = n_vocab
|
||||||
|
n_vocab += 1
|
||||||
|
|
||||||
|
return tiktoken.Encoding(
|
||||||
|
name=os.path.basename(vocab_path),
|
||||||
|
explicit_n_vocab=n_vocab,
|
||||||
|
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||||
|
mergeable_ranks=ranks,
|
||||||
|
special_tokens=special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_tokenizer(
|
||||||
|
multilingual: bool,
|
||||||
|
*,
|
||||||
|
num_languages: int = 99,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||||
|
) -> Tokenizer:
|
||||||
|
if language is not None:
|
||||||
|
language = language.lower()
|
||||||
|
if language not in LANGUAGES:
|
||||||
|
if language in TO_LANGUAGE_CODE:
|
||||||
|
language = TO_LANGUAGE_CODE[language]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported language: {language}")
|
||||||
|
|
||||||
|
if multilingual:
|
||||||
|
encoding_name = "multilingual_zh_ja_yue_char_del"
|
||||||
|
language = language or "en"
|
||||||
|
task = task or "transcribe"
|
||||||
|
else:
|
||||||
|
encoding_name = "gpt2"
|
||||||
|
language = None
|
||||||
|
task = None
|
||||||
|
|
||||||
|
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||||
|
|
||||||
|
return Tokenizer(
|
||||||
|
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2Tokenizer():
|
||||||
|
def __init__(self, token_path, skip_special_tokens=True):
|
||||||
|
super().__init__()
|
||||||
|
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
||||||
|
special_tokens = {
|
||||||
|
'eos_token': '<|endoftext|>',
|
||||||
|
'pad_token': '<|endoftext|>',
|
||||||
|
'additional_special_tokens': [
|
||||||
|
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||||
|
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||||
|
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||||
|
'[quick_breath]',
|
||||||
|
"<laughter>", "</laughter>",
|
||||||
|
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||||
|
"[lipsmack]", "[mn]"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.special_tokens = special_tokens
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||||
|
self.tokenizer.add_special_tokens(special_tokens)
|
||||||
|
self.skip_special_tokens = skip_special_tokens
|
||||||
|
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
tokens = self.tokenizer([text], return_tensors="pt")
|
||||||
|
tokens = tokens["input_ids"][0].cpu().tolist()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
tokens = torch.tensor(tokens, dtype=torch.int64)
|
||||||
|
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
|
||||||
|
def __init__(self, token_path, skip_special_tokens=True):
|
||||||
|
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
||||||
|
special_tokens = {
|
||||||
|
'eos_token': '<|endoftext|>',
|
||||||
|
'pad_token': '<|endoftext|>',
|
||||||
|
'additional_special_tokens': [
|
||||||
|
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||||
|
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||||
|
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||||
|
'[quick_breath]',
|
||||||
|
"<laughter>", "</laughter>",
|
||||||
|
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||||
|
"[lipsmack]", "[mn]", "<|endofsystem|>",
|
||||||
|
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
|
||||||
|
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
|
||||||
|
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
|
||||||
|
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
|
||||||
|
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
|
||||||
|
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
|
||||||
|
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
|
||||||
|
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
|
||||||
|
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
|
||||||
|
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
|
||||||
|
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
|
||||||
|
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
|
||||||
|
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
|
||||||
|
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
|
||||||
|
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
|
||||||
|
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
|
||||||
|
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
|
||||||
|
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
|
||||||
|
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
|
||||||
|
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.special_tokens = special_tokens
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||||
|
self.tokenizer.add_special_tokens(special_tokens)
|
||||||
|
self.skip_special_tokens = skip_special_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_qwen_tokenizer(
|
||||||
|
token_path: str,
|
||||||
|
skip_special_tokens: bool,
|
||||||
|
version: str = 'cosyvoice2'
|
||||||
|
):
|
||||||
|
if version == 'cosyvoice2':
|
||||||
|
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||||
|
elif version == 'cosyvoice3':
|
||||||
|
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|||||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||||
|
|
||||||
def rel_shift(self, x):
|
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
|||||||
torch.Tensor: Output tensor.
|
torch.Tensor: Output tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype)
|
||||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||||
|
|
||||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
x_padded = x_padded.view(x.size()[0],
|
||||||
|
x.size()[1],
|
||||||
|
x.size(3) + 1, x.size(2))
|
||||||
x = x_padded[:, :, 1:].view_as(x)[
|
x = x_padded[:, :, 1:].view_as(x)[
|
||||||
:, :, :, : x.size(-1) // 2 + 1
|
:, :, :, : x.size(-1) // 2 + 1
|
||||||
] # only keep the positions from 0 to time2
|
] # only keep the positions from 0 to time2
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionModule(nn.Module):
|
class ConvolutionModule(nn.Module):
|
||||||
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
|
|||||||
x.masked_fill_(~mask_pad, 0.0)
|
x.masked_fill_(~mask_pad, 0.0)
|
||||||
|
|
||||||
return x.transpose(1, 2), new_cache
|
return x.transpose(1, 2), new_cache
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
|
||||||
|
class CausalConv1d(torch.nn.Conv1d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: int = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
padding_mode: str = 'zeros',
|
||||||
|
causal_type: str = 'left',
|
||||||
|
device=None,
|
||||||
|
dtype=None
|
||||||
|
) -> None:
|
||||||
|
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||||
|
kernel_size, stride=1,
|
||||||
|
padding=0, dilation=dilation,
|
||||||
|
groups=groups, bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
assert stride == 1
|
||||||
|
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
|
||||||
|
assert causal_type in ['left', 'right']
|
||||||
|
self.causal_type = causal_type
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
|
||||||
|
input_timestep = x.shape[2]
|
||||||
|
if cache.size(2) == 0:
|
||||||
|
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
|
||||||
|
assert cache.size(2) == self.causal_padding
|
||||||
|
if self.causal_type == 'left':
|
||||||
|
x = torch.concat([cache, x], dim=2)
|
||||||
|
else:
|
||||||
|
x = torch.concat([x, cache], dim=2)
|
||||||
|
x = super(CausalConv1d, self).forward(x)
|
||||||
|
assert x.shape[2] == input_timestep
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1dDownSample(torch.nn.Conv1d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: int = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
padding_mode: str = 'zeros',
|
||||||
|
device=None,
|
||||||
|
dtype=None
|
||||||
|
) -> None:
|
||||||
|
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
|
||||||
|
kernel_size, stride,
|
||||||
|
padding=0, dilation=dilation,
|
||||||
|
groups=groups, bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
assert stride != 1 and dilation == 1
|
||||||
|
assert kernel_size % stride == 0
|
||||||
|
self.causal_padding = stride - 1
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if cache.size(2) == 0:
|
||||||
|
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||||
|
else:
|
||||||
|
assert cache.size(2) == self.causal_padding
|
||||||
|
x = torch.concat([cache, x], dim=2)
|
||||||
|
x = super(CausalConv1dDownSample, self).forward(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1dUpsample(torch.nn.Conv1d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: int = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
padding_mode: str = 'zeros',
|
||||||
|
device=None,
|
||||||
|
dtype=None
|
||||||
|
) -> None:
|
||||||
|
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
|
||||||
|
kernel_size, 1,
|
||||||
|
padding=0, dilation=dilation,
|
||||||
|
groups=groups, bias=bias,
|
||||||
|
padding_mode=padding_mode,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
assert dilation == 1
|
||||||
|
self.causal_padding = kernel_size - 1
|
||||||
|
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
x = self.upsample(x)
|
||||||
|
input_timestep = x.shape[2]
|
||||||
|
if cache.size(2) == 0:
|
||||||
|
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||||
|
else:
|
||||||
|
assert cache.size(2) == self.causal_padding
|
||||||
|
x = torch.concat([cache, x], dim=2)
|
||||||
|
x = super(CausalConv1dUpsample, self).forward(x)
|
||||||
|
assert input_timestep == x.shape[2]
|
||||||
|
return x
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
|
|||||||
memory_mask)
|
memory_mask)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@torch.jit.ignore(drop=True)
|
@torch.jit.unused
|
||||||
def forward_layers_checkpointed(self, x: torch.Tensor,
|
def forward_layers_checkpointed(self, x: torch.Tensor,
|
||||||
tgt_mask: torch.Tensor,
|
tgt_mask: torch.Tensor,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
||||||
"""Construct an PositionalEncoding object."""
|
"""Construct an PositionalEncoding object."""
|
||||||
super(EspnetRelPositionalEncoding, self).__init__()
|
super(EspnetRelPositionalEncoding, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x):
|
def extend_pe(self, x: torch.Tensor):
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
|
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
||||||
|
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -286,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Corresponding encoding
|
torch.Tensor: Corresponding encoding
|
||||||
"""
|
"""
|
||||||
pos_emb = self.pe[
|
# How to subscript a Union type:
|
||||||
:,
|
# https://github.com/pytorch/pytorch/issues/69434
|
||||||
self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
|
if isinstance(offset, int):
|
||||||
]
|
pos_emb = self.pe[
|
||||||
|
:,
|
||||||
|
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
||||||
|
]
|
||||||
|
elif isinstance(offset, torch.Tensor):
|
||||||
|
pos_emb = self.pe[
|
||||||
|
:,
|
||||||
|
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
||||||
|
]
|
||||||
return pos_emb
|
return pos_emb
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
return xs
|
return xs
|
||||||
|
|
||||||
@torch.jit.ignore(drop=True)
|
@torch.jit.unused
|
||||||
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
||||||
chunk_masks: torch.Tensor,
|
chunk_masks: torch.Tensor,
|
||||||
pos_emb: torch.Tensor,
|
pos_emb: torch.Tensor,
|
||||||
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
mask_pad)
|
mask_pad)
|
||||||
return xs
|
return xs
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def forward_chunk(
|
def forward_chunk(
|
||||||
self,
|
self,
|
||||||
xs: torch.Tensor,
|
xs: torch.Tensor,
|
||||||
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
|
|||||||
|
|
||||||
return (xs, r_att_cache, r_cnn_cache)
|
return (xs, r_att_cache, r_cnn_cache)
|
||||||
|
|
||||||
|
@torch.jit.unused
|
||||||
def forward_chunk_by_chunk(
|
def forward_chunk_by_chunk(
|
||||||
self,
|
self,
|
||||||
xs: torch.Tensor,
|
xs: torch.Tensor,
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = self_attn
|
self.self_attn = self_attn
|
||||||
self.feed_forward = feed_forward
|
self.feed_forward = feed_forward
|
||||||
self.norm1 = nn.LayerNorm(size, eps=1e-5)
|
self.norm1 = nn.LayerNorm(size, eps=1e-12)
|
||||||
self.norm2 = nn.LayerNorm(size, eps=1e-5)
|
self.norm2 = nn.LayerNorm(size, eps=1e-12)
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self.feed_forward = feed_forward
|
self.feed_forward = feed_forward
|
||||||
self.feed_forward_macaron = feed_forward_macaron
|
self.feed_forward_macaron = feed_forward_macaron
|
||||||
self.conv_module = conv_module
|
self.conv_module = conv_module
|
||||||
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
|
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
||||||
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
|
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
||||||
if feed_forward_macaron is not None:
|
if feed_forward_macaron is not None:
|
||||||
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
|
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
||||||
self.ff_scale = 0.5
|
self.ff_scale = 0.5
|
||||||
else:
|
else:
|
||||||
self.ff_scale = 1.0
|
self.ff_scale = 1.0
|
||||||
if self.conv_module is not None:
|
if self.conv_module is not None:
|
||||||
self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
|
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
||||||
self.norm_final = nn.LayerNorm(
|
self.norm_final = nn.LayerNorm(
|
||||||
size, eps=1e-5) # for the final output of the block
|
size, eps=1e-12) # for the final output of the block
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
|
|||||||
321
cosyvoice/transformer/upsample_encoder.py
Normal file
321
cosyvoice/transformer/upsample_encoder.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||||
|
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||||
|
# 2024 Alibaba Inc (Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||||
|
"""Encoder definition."""
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from cosyvoice.transformer.convolution import ConvolutionModule
|
||||||
|
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
||||||
|
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||||
|
from cosyvoice.utils.class_utils import (
|
||||||
|
COSYVOICE_EMB_CLASSES,
|
||||||
|
COSYVOICE_SUBSAMPLE_CLASSES,
|
||||||
|
COSYVOICE_ATTENTION_CLASSES,
|
||||||
|
COSYVOICE_ACTIVATION_CLASSES,
|
||||||
|
)
|
||||||
|
from cosyvoice.utils.mask import make_pad_mask
|
||||||
|
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1D(nn.Module):
|
||||||
|
"""A 1D upsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
channels (`int`):
|
||||||
|
number of channels in the inputs and outputs.
|
||||||
|
use_conv (`bool`, default `False`):
|
||||||
|
option to use a convolution.
|
||||||
|
use_conv_transpose (`bool`, default `False`):
|
||||||
|
option to use a convolution transpose.
|
||||||
|
out_channels (`int`, optional):
|
||||||
|
number of output channels. Defaults to `channels`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.stride = stride
|
||||||
|
# In this mode, first repeat interpolate, than conv with stride=1
|
||||||
|
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
||||||
|
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
||||||
|
outputs = self.conv(outputs)
|
||||||
|
return outputs, input_lengths * self.stride
|
||||||
|
|
||||||
|
|
||||||
|
class PreLookaheadLayer(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.channels = channels
|
||||||
|
self.pre_lookahead_len = pre_lookahead_len
|
||||||
|
self.conv1 = nn.Conv1d(
|
||||||
|
in_channels, channels,
|
||||||
|
kernel_size=pre_lookahead_len + 1,
|
||||||
|
stride=1, padding=0,
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv1d(
|
||||||
|
channels, in_channels,
|
||||||
|
kernel_size=3, stride=1, padding=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
inputs: (batch_size, seq_len, channels)
|
||||||
|
"""
|
||||||
|
outputs = inputs.transpose(1, 2).contiguous()
|
||||||
|
context = context.transpose(1, 2).contiguous()
|
||||||
|
# look ahead
|
||||||
|
if context.size(2) == 0:
|
||||||
|
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
||||||
|
else:
|
||||||
|
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||||
|
assert context.size(2) == self.pre_lookahead_len
|
||||||
|
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
||||||
|
outputs = F.leaky_relu(self.conv1(outputs))
|
||||||
|
# outputs
|
||||||
|
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
||||||
|
outputs = self.conv2(outputs)
|
||||||
|
outputs = outputs.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
# residual connection
|
||||||
|
outputs = outputs + inputs
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampleConformerEncoder(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int = 256,
|
||||||
|
attention_heads: int = 4,
|
||||||
|
linear_units: int = 2048,
|
||||||
|
num_blocks: int = 6,
|
||||||
|
dropout_rate: float = 0.1,
|
||||||
|
positional_dropout_rate: float = 0.1,
|
||||||
|
attention_dropout_rate: float = 0.0,
|
||||||
|
input_layer: str = "conv2d",
|
||||||
|
pos_enc_layer_type: str = "rel_pos",
|
||||||
|
normalize_before: bool = True,
|
||||||
|
static_chunk_size: int = 0,
|
||||||
|
use_dynamic_chunk: bool = False,
|
||||||
|
global_cmvn: torch.nn.Module = None,
|
||||||
|
use_dynamic_left_chunk: bool = False,
|
||||||
|
positionwise_conv_kernel_size: int = 1,
|
||||||
|
macaron_style: bool = True,
|
||||||
|
selfattention_layer_type: str = "rel_selfattn",
|
||||||
|
activation_type: str = "swish",
|
||||||
|
use_cnn_module: bool = True,
|
||||||
|
cnn_module_kernel: int = 15,
|
||||||
|
causal: bool = False,
|
||||||
|
cnn_module_norm: str = "batch_norm",
|
||||||
|
key_bias: bool = True,
|
||||||
|
gradient_checkpointing: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_size (int): input dim
|
||||||
|
output_size (int): dimension of attention
|
||||||
|
attention_heads (int): the number of heads of multi head attention
|
||||||
|
linear_units (int): the hidden units number of position-wise feed
|
||||||
|
forward
|
||||||
|
num_blocks (int): the number of decoder blocks
|
||||||
|
dropout_rate (float): dropout rate
|
||||||
|
attention_dropout_rate (float): dropout rate in attention
|
||||||
|
positional_dropout_rate (float): dropout rate after adding
|
||||||
|
positional encoding
|
||||||
|
input_layer (str): input layer type.
|
||||||
|
optional [linear, conv2d, conv2d6, conv2d8]
|
||||||
|
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||||
|
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
||||||
|
normalize_before (bool):
|
||||||
|
True: use layer_norm before each sub-block of a layer.
|
||||||
|
False: use layer_norm after each sub-block of a layer.
|
||||||
|
static_chunk_size (int): chunk size for static chunk training and
|
||||||
|
decoding
|
||||||
|
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
||||||
|
training or not, You can only use fixed chunk(chunk_size > 0)
|
||||||
|
or dyanmic chunk size(use_dynamic_chunk = True)
|
||||||
|
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
||||||
|
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
||||||
|
dynamic chunk training
|
||||||
|
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||||
|
gradient_checkpointing: rerunning a forward-pass segment for each
|
||||||
|
checkpointed segment during backward.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._output_size = output_size
|
||||||
|
|
||||||
|
self.global_cmvn = global_cmvn
|
||||||
|
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
dropout_rate,
|
||||||
|
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||||
|
positional_dropout_rate),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
||||||
|
self.static_chunk_size = static_chunk_size
|
||||||
|
self.use_dynamic_chunk = use_dynamic_chunk
|
||||||
|
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
||||||
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
|
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||||
|
# self-attention module definition
|
||||||
|
encoder_selfattn_layer_args = (
|
||||||
|
attention_heads,
|
||||||
|
output_size,
|
||||||
|
attention_dropout_rate,
|
||||||
|
key_bias,
|
||||||
|
)
|
||||||
|
# feed-forward module definition
|
||||||
|
positionwise_layer_args = (
|
||||||
|
output_size,
|
||||||
|
linear_units,
|
||||||
|
dropout_rate,
|
||||||
|
activation,
|
||||||
|
)
|
||||||
|
# convolution module definition
|
||||||
|
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
||||||
|
cnn_module_norm, causal)
|
||||||
|
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
|
||||||
|
self.encoders = torch.nn.ModuleList([
|
||||||
|
ConformerEncoderLayer(
|
||||||
|
output_size,
|
||||||
|
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||||
|
*encoder_selfattn_layer_args),
|
||||||
|
PositionwiseFeedForward(*positionwise_layer_args),
|
||||||
|
PositionwiseFeedForward(
|
||||||
|
*positionwise_layer_args) if macaron_style else None,
|
||||||
|
ConvolutionModule(
|
||||||
|
*convolution_layer_args) if use_cnn_module else None,
|
||||||
|
dropout_rate,
|
||||||
|
normalize_before,
|
||||||
|
) for _ in range(num_blocks)
|
||||||
|
])
|
||||||
|
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
||||||
|
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||||
|
input_size,
|
||||||
|
output_size,
|
||||||
|
dropout_rate,
|
||||||
|
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||||
|
positional_dropout_rate),
|
||||||
|
)
|
||||||
|
self.up_encoders = torch.nn.ModuleList([
|
||||||
|
ConformerEncoderLayer(
|
||||||
|
output_size,
|
||||||
|
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||||
|
*encoder_selfattn_layer_args),
|
||||||
|
PositionwiseFeedForward(*positionwise_layer_args),
|
||||||
|
PositionwiseFeedForward(
|
||||||
|
*positionwise_layer_args) if macaron_style else None,
|
||||||
|
ConvolutionModule(
|
||||||
|
*convolution_layer_args) if use_cnn_module else None,
|
||||||
|
dropout_rate,
|
||||||
|
normalize_before,
|
||||||
|
) for _ in range(4)
|
||||||
|
])
|
||||||
|
|
||||||
|
def output_size(self) -> int:
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
xs: torch.Tensor,
|
||||||
|
xs_lens: torch.Tensor,
|
||||||
|
context: torch.Tensor = torch.zeros(0, 0, 0),
|
||||||
|
decoding_chunk_size: int = 0,
|
||||||
|
num_decoding_left_chunks: int = -1,
|
||||||
|
streaming: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Embed positions in tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xs: padded input tensor (B, T, D)
|
||||||
|
xs_lens: input length (B)
|
||||||
|
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||||
|
0: default for training, use random dynamic chunk.
|
||||||
|
<0: for decoding, use full chunk.
|
||||||
|
>0: for decoding, use fixed chunk size as set.
|
||||||
|
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||||
|
the chunk size is decoding_chunk_size.
|
||||||
|
>=0: use num_decoding_left_chunks
|
||||||
|
<0: use all left chunks
|
||||||
|
Returns:
|
||||||
|
encoder output tensor xs, and subsampled masks
|
||||||
|
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||||
|
masks: torch.Tensor batch padding mask after subsample
|
||||||
|
(B, 1, T' ~= T/subsample_rate)
|
||||||
|
NOTE(xcsong):
|
||||||
|
We pass the `__call__` method of the modules instead of `forward` to the
|
||||||
|
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||||
|
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||||
|
"""
|
||||||
|
T = xs.size(1)
|
||||||
|
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||||
|
if self.global_cmvn is not None:
|
||||||
|
xs = self.global_cmvn(xs)
|
||||||
|
xs, pos_emb, masks = self.embed(xs, masks)
|
||||||
|
if context.size(1) != 0:
|
||||||
|
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||||
|
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
||||||
|
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
||||||
|
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||||
|
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
||||||
|
# lookahead + conformer encoder
|
||||||
|
xs = self.pre_lookahead_layer(xs, context=context)
|
||||||
|
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
|
|
||||||
|
# upsample + conformer encoder
|
||||||
|
xs = xs.transpose(1, 2).contiguous()
|
||||||
|
xs, xs_lens = self.up_layer(xs, xs_lens)
|
||||||
|
xs = xs.transpose(1, 2).contiguous()
|
||||||
|
T = xs.size(1)
|
||||||
|
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||||
|
xs, pos_emb, masks = self.up_embed(xs, masks)
|
||||||
|
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||||
|
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
|
||||||
|
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
xs = self.after_norm(xs)
|
||||||
|
# Here we assume the mask is not changed in encoder layers, so just
|
||||||
|
# return the masks before encoder layers, and the masks will be used
|
||||||
|
# for cross attention with decoder later
|
||||||
|
return xs, masks
|
||||||
|
|
||||||
|
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||||
|
pos_emb: torch.Tensor,
|
||||||
|
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||||
|
for layer in self.encoders:
|
||||||
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
|
return xs
|
||||||
|
|
||||||
|
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||||
|
pos_emb: torch.Tensor,
|
||||||
|
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||||
|
for layer in self.up_encoders:
|
||||||
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
|
return xs
|
||||||
@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
|||||||
RelPositionMultiHeadedAttention)
|
RelPositionMultiHeadedAttention)
|
||||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||||
|
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
|
||||||
|
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
||||||
|
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
||||||
|
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||||
|
|
||||||
|
|
||||||
COSYVOICE_ACTIVATION_CLASSES = {
|
COSYVOICE_ACTIVATION_CLASSES = {
|
||||||
@@ -68,3 +72,14 @@ COSYVOICE_ATTENTION_CLASSES = {
|
|||||||
"selfattn": MultiHeadedAttention,
|
"selfattn": MultiHeadedAttention,
|
||||||
"rel_selfattn": RelPositionMultiHeadedAttention,
|
"rel_selfattn": RelPositionMultiHeadedAttention,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_type(configs):
|
||||||
|
# NOTE CosyVoice2Model inherits CosyVoiceModel
|
||||||
|
if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||||
|
return CosyVoiceModel
|
||||||
|
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||||
|
return CosyVoice2Model
|
||||||
|
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
||||||
|
return CosyVoice3Model
|
||||||
|
raise TypeError('No valid model type found!')
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -15,12 +16,42 @@
|
|||||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||||
"""Unility functions for Transformer."""
|
"""Unility functions for Transformer."""
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import random
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
IGNORE_ID = -1
|
IGNORE_ID = -1
|
||||||
|
|
||||||
|
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
|
||||||
|
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
|
||||||
|
|
||||||
|
|
||||||
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
||||||
"""Perform padding for the list of tensors.
|
"""Perform padding for the list of tensors.
|
||||||
@@ -101,3 +132,83 @@ def init_weights(m, mean=0.0, std=0.01):
|
|||||||
classname = m.__class__.__name__
|
classname = m.__class__.__name__
|
||||||
if classname.find("Conv") != -1:
|
if classname.find("Conv") != -1:
|
||||||
m.weight.data.normal_(mean, std)
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
# Repetition Aware Sampling in VALL-E 2
|
||||||
|
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||||
|
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||||
|
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
||||||
|
if rep_num >= win_size * tau_r:
|
||||||
|
weighted_scores[top_ids] = -float('inf')
|
||||||
|
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
|
||||||
|
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||||
|
prob, indices = [], []
|
||||||
|
cum_prob = 0.0
|
||||||
|
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
||||||
|
for i in range(len(sorted_idx)):
|
||||||
|
# sampling both top-p and numbers.
|
||||||
|
if cum_prob < top_p and len(prob) < top_k:
|
||||||
|
cum_prob += sorted_value[i]
|
||||||
|
prob.append(sorted_value[i])
|
||||||
|
indices.append(sorted_idx[i])
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
prob = torch.tensor(prob).to(weighted_scores)
|
||||||
|
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||||
|
top_ids = indices[prob.multinomial(1, replacement=True)].item()
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
|
||||||
|
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||||
|
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
|
||||||
|
return top_ids
|
||||||
|
|
||||||
|
|
||||||
|
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||||
|
device = fade_in_mel.device
|
||||||
|
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||||
|
mel_overlap_len = int(window.shape[0] / 2)
|
||||||
|
if fade_in_mel.device == torch.device('cpu'):
|
||||||
|
fade_in_mel = fade_in_mel.clone()
|
||||||
|
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||||
|
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||||
|
return fade_in_mel.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def set_all_random_seed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
assert mask.dtype == torch.bool
|
||||||
|
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
||||||
|
mask = mask.to(dtype)
|
||||||
|
# attention mask bias
|
||||||
|
# NOTE(Mddct): torch.finfo jit issues
|
||||||
|
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
||||||
|
mask = (1.0 - mask) * -1.0e+10
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
class TrtContextWrapper:
|
||||||
|
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||||
|
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||||
|
self.trt_engine = trt_engine
|
||||||
|
for _ in range(trt_concurrent):
|
||||||
|
trt_context = trt_engine.create_execution_context()
|
||||||
|
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
|
||||||
|
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||||
|
self.trt_context_pool.put([trt_context, trt_stream])
|
||||||
|
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||||
|
|
||||||
|
def acquire_estimator(self):
|
||||||
|
return self.trt_context_pool.get(), self.trt_engine
|
||||||
|
|
||||||
|
def release_estimator(self, context, stream):
|
||||||
|
self.trt_context_pool.put([context, stream])
|
||||||
|
|||||||
@@ -25,13 +25,68 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
|
|||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
|
||||||
|
self.gan = gan
|
||||||
|
self.ref_model = ref_model
|
||||||
|
self.dpo_loss = dpo_loss
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.rank = int(os.environ.get('RANK', 0))
|
self.rank = int(os.environ.get('RANK', 0))
|
||||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||||
|
|
||||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
|
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
|
||||||
|
''' Train one epoch
|
||||||
|
'''
|
||||||
|
|
||||||
|
lr = optimizer.param_groups[0]['lr']
|
||||||
|
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
||||||
|
logging.info('using accumulate grad, new batch size is {} times'
|
||||||
|
' larger than before'.format(info_dict['accum_grad']))
|
||||||
|
# A context manager to be used in conjunction with an instance of
|
||||||
|
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||||
|
# with uneven inputs across participating processes.
|
||||||
|
model.train()
|
||||||
|
if self.ref_model is not None:
|
||||||
|
self.ref_model.eval()
|
||||||
|
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||||
|
with model_context():
|
||||||
|
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||||
|
info_dict["tag"] = "TRAIN"
|
||||||
|
info_dict["step"] = self.step
|
||||||
|
info_dict["epoch"] = self.epoch
|
||||||
|
info_dict["batch_idx"] = batch_idx
|
||||||
|
if cosyvoice_join(group_join, info_dict):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Disable gradient synchronizations across DDP processes.
|
||||||
|
# Within this context, gradients will be accumulated on module
|
||||||
|
# variables, which will later be synchronized.
|
||||||
|
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
||||||
|
context = model.no_sync
|
||||||
|
# Used for single gpu training and DDP gradient synchronization
|
||||||
|
# processes.
|
||||||
|
else:
|
||||||
|
context = nullcontext
|
||||||
|
|
||||||
|
with context():
|
||||||
|
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
|
||||||
|
info_dict = batch_backward(model, scaler, info_dict)
|
||||||
|
|
||||||
|
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||||
|
log_per_step(writer, info_dict)
|
||||||
|
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||||
|
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||||
|
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||||
|
dist.barrier()
|
||||||
|
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||||
|
model.train()
|
||||||
|
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||||
|
self.step += 1
|
||||||
|
dist.barrier()
|
||||||
|
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||||
|
|
||||||
|
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||||
|
writer, info_dict, scaler, group_join):
|
||||||
''' Train one epoch
|
''' Train one epoch
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@@ -64,13 +119,22 @@ class Executor:
|
|||||||
context = nullcontext
|
context = nullcontext
|
||||||
|
|
||||||
with context():
|
with context():
|
||||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
batch_dict['turn'] = 'discriminator'
|
||||||
info_dict = batch_backward(model, info_dict)
|
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||||
|
info_dict = batch_backward(model, scaler, info_dict)
|
||||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
|
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
log_per_step(writer, info_dict)
|
||||||
|
with context():
|
||||||
|
batch_dict['turn'] = 'generator'
|
||||||
|
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||||
|
info_dict = batch_backward(model, scaler, info_dict)
|
||||||
|
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||||
|
optimizer_d.zero_grad()
|
||||||
log_per_step(writer, info_dict)
|
log_per_step(writer, info_dict)
|
||||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||||
|
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||||
model.train()
|
model.train()
|
||||||
@@ -95,12 +159,14 @@ class Executor:
|
|||||||
num_utts = len(batch_dict["utts"])
|
num_utts = len(batch_dict["utts"])
|
||||||
total_num_utts += num_utts
|
total_num_utts += num_utts
|
||||||
|
|
||||||
info_dict = batch_forward(model, batch_dict, info_dict)
|
if self.gan is True:
|
||||||
|
batch_dict['turn'] = 'generator'
|
||||||
|
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
||||||
|
|
||||||
for k, v in info_dict['loss_dict'].items():
|
for k, v in info_dict['loss_dict'].items():
|
||||||
if k not in total_loss_dict:
|
if k not in total_loss_dict:
|
||||||
total_loss_dict[k] = []
|
total_loss_dict[k] = []
|
||||||
total_loss_dict[k].append(v.item() * num_utts)
|
total_loss_dict[k].append(v.mean().item() * num_utts)
|
||||||
log_per_step(None, info_dict)
|
log_per_step(None, info_dict)
|
||||||
for k, v in total_loss_dict.items():
|
for k, v in total_loss_dict.items():
|
||||||
total_loss_dict[k] = sum(v) / total_num_utts
|
total_loss_dict[k] = sum(v) / total_num_utts
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
||||||
|
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -13,8 +14,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
import json
|
import json
|
||||||
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
import logging
|
||||||
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
|
||||||
def read_lists(list_file):
|
def read_lists(list_file):
|
||||||
@@ -24,6 +31,7 @@ def read_lists(list_file):
|
|||||||
lists.append(line.strip())
|
lists.append(line.strip())
|
||||||
return lists
|
return lists
|
||||||
|
|
||||||
|
|
||||||
def read_json_lists(list_file):
|
def read_json_lists(list_file):
|
||||||
lists = read_lists(list_file)
|
lists = read_lists(list_file)
|
||||||
results = {}
|
results = {}
|
||||||
@@ -32,22 +40,79 @@ def read_json_lists(list_file):
|
|||||||
results.update(json.load(fin))
|
results.update(json.load(fin))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def load_wav(wav, target_sr):
|
|
||||||
speech, sample_rate = torchaudio.load(wav)
|
def load_wav(wav, target_sr, min_sr=16000):
|
||||||
|
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
|
||||||
speech = speech.mean(dim=0, keepdim=True)
|
speech = speech.mean(dim=0, keepdim=True)
|
||||||
if sample_rate != target_sr:
|
if sample_rate != target_sr:
|
||||||
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
||||||
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
def speed_change(waveform, sample_rate, speed_factor: str):
|
|
||||||
effects = [
|
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||||
["tempo", speed_factor], # speed_factor
|
import tensorrt as trt
|
||||||
["rate", f"{sample_rate}"]
|
logging.info("Converting onnx to trt...")
|
||||||
]
|
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||||
augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
waveform,
|
builder = trt.Builder(logger)
|
||||||
sample_rate,
|
network = builder.create_network(network_flags)
|
||||||
effects
|
parser = trt.OnnxParser(network, logger)
|
||||||
)
|
config = builder.create_builder_config()
|
||||||
return augmented_waveform, new_sample_rate
|
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||||
|
if fp16:
|
||||||
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
|
profile = builder.create_optimization_profile()
|
||||||
|
# load onnx model
|
||||||
|
with open(onnx_model, "rb") as f:
|
||||||
|
if not parser.parse(f.read()):
|
||||||
|
for error in range(parser.num_errors):
|
||||||
|
print(parser.get_error(error))
|
||||||
|
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||||
|
# set input shapes
|
||||||
|
for i in range(len(trt_kwargs['input_names'])):
|
||||||
|
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||||
|
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
||||||
|
# set input and output data type
|
||||||
|
for i in range(network.num_inputs):
|
||||||
|
input_tensor = network.get_input(i)
|
||||||
|
input_tensor.dtype = tensor_dtype
|
||||||
|
for i in range(network.num_outputs):
|
||||||
|
output_tensor = network.get_output(i)
|
||||||
|
output_tensor.dtype = tensor_dtype
|
||||||
|
config.add_optimization_profile(profile)
|
||||||
|
engine_bytes = builder.build_serialized_network(network, config)
|
||||||
|
# save trt engine
|
||||||
|
with open(trt_model, "wb") as f:
|
||||||
|
f.write(engine_bytes)
|
||||||
|
logging.info("Succesfully convert onnx to trt...")
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE do not support bistream inference as only speech token embedding/head is kept
|
||||||
|
def export_cosyvoice2_vllm(model, model_path, device):
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
# lm_head
|
||||||
|
use_bias = True if model.llm_decoder.bias is not None else False
|
||||||
|
model.llm.model.lm_head = model.llm_decoder
|
||||||
|
# embed_tokens
|
||||||
|
embed_tokens = model.llm.model.model.embed_tokens
|
||||||
|
model.llm.model.set_input_embeddings(model.speech_embedding)
|
||||||
|
model.llm.model.to(device)
|
||||||
|
model.llm.model.to(dtype)
|
||||||
|
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||||
|
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
|
||||||
|
del model.llm.model.generation_config.eos_token_id
|
||||||
|
del model.llm.model.config.bos_token_id
|
||||||
|
del model.llm.model.config.eos_token_id
|
||||||
|
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
|
||||||
|
model.llm.model.config.tie_word_embeddings = False
|
||||||
|
model.llm.model.config.use_bias = use_bias
|
||||||
|
model.llm.model.save_pretrained(model_path)
|
||||||
|
if use_bias is True:
|
||||||
|
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||||
|
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||||
|
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||||
|
model.llm.model.set_input_embeddings(embed_tokens)
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import regex
|
||||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||||
|
|
||||||
|
|
||||||
# whether contain chinese character
|
# whether contain chinese character
|
||||||
def contains_chinese(text):
|
def contains_chinese(text):
|
||||||
return bool(chinese_char_pattern.search(text))
|
return bool(chinese_char_pattern.search(text))
|
||||||
@@ -79,6 +81,13 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
|||||||
pounc = ['.', '?', '!', ';', ':']
|
pounc = ['.', '?', '!', ';', ':']
|
||||||
if comma_split:
|
if comma_split:
|
||||||
pounc.extend([',', ','])
|
pounc.extend([',', ','])
|
||||||
|
|
||||||
|
if text[-1] not in pounc:
|
||||||
|
if lang == "zh":
|
||||||
|
text += "。"
|
||||||
|
else:
|
||||||
|
text += "."
|
||||||
|
|
||||||
st = 0
|
st = 0
|
||||||
utts = []
|
utts = []
|
||||||
for i, c in enumerate(text):
|
for i, c in enumerate(text):
|
||||||
@@ -91,11 +100,7 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
|||||||
st = i + 2
|
st = i + 2
|
||||||
else:
|
else:
|
||||||
st = i + 1
|
st = i + 1
|
||||||
if len(utts) == 0:
|
|
||||||
if lang == "zh":
|
|
||||||
utts.append(text + '。')
|
|
||||||
else:
|
|
||||||
utts.append(text + '.')
|
|
||||||
final_utts = []
|
final_utts = []
|
||||||
cur_utt = ""
|
cur_utt = ""
|
||||||
for utt in utts:
|
for utt in utts:
|
||||||
@@ -123,3 +128,9 @@ def replace_blank(text: str):
|
|||||||
else:
|
else:
|
||||||
out_str.append(c)
|
out_str.append(c)
|
||||||
return "".join(out_str)
|
return "".join(out_str)
|
||||||
|
|
||||||
|
|
||||||
|
def is_only_punctuation(text):
|
||||||
|
# Regular expression: Match strings that consist only of punctuation marks or are empty.
|
||||||
|
punctuation_pattern = r'^[\p{P}\p{S}]*$'
|
||||||
|
return bool(regex.fullmatch(punctuation_pattern, text))
|
||||||
|
|||||||
57
cosyvoice/utils/losses.py
Normal file
57
cosyvoice/utils/losses.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||||
|
loss = 0
|
||||||
|
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||||
|
m_DG = torch.median((dr - dg))
|
||||||
|
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
|
||||||
|
loss += tau - F.relu(tau - L_rel)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||||
|
loss = 0
|
||||||
|
for transform in mel_transforms:
|
||||||
|
mel_r = transform(real_speech)
|
||||||
|
mel_g = transform(generated_speech)
|
||||||
|
loss += F.l1_loss(mel_g, mel_r)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class DPOLoss(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
DPO Loss
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.beta = beta
|
||||||
|
self.label_smoothing = label_smoothing
|
||||||
|
self.ipo = ipo
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
policy_chosen_logps: torch.Tensor,
|
||||||
|
policy_rejected_logps: torch.Tensor,
|
||||||
|
reference_chosen_logps: torch.Tensor,
|
||||||
|
reference_rejected_logps: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||||
|
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||||
|
logits = pi_logratios - ref_logratios
|
||||||
|
if self.ipo:
|
||||||
|
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
|
||||||
|
else:
|
||||||
|
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
|
||||||
|
losses = (
|
||||||
|
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||||
|
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||||
|
)
|
||||||
|
loss = losses.mean()
|
||||||
|
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
||||||
|
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
||||||
|
|
||||||
|
return loss, chosen_rewards, rejected_rewards
|
||||||
@@ -86,7 +86,7 @@ def subsequent_mask(
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def subsequent_chunk_mask(
|
def subsequent_chunk_mask_deprecated(
|
||||||
size: int,
|
size: int,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
num_left_chunks: int = -1,
|
num_left_chunks: int = -1,
|
||||||
@@ -124,6 +124,40 @@ def subsequent_chunk_mask(
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def subsequent_chunk_mask(
|
||||||
|
size: int,
|
||||||
|
chunk_size: int,
|
||||||
|
num_left_chunks: int = -1,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||||
|
this is for streaming encoder
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (int): size of mask
|
||||||
|
chunk_size (int): size of chunk
|
||||||
|
num_left_chunks (int): number of left chunks
|
||||||
|
<0: use full chunk
|
||||||
|
>=0: use num_left_chunks
|
||||||
|
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: mask
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> subsequent_chunk_mask(4, 2)
|
||||||
|
[[1, 1, 0, 0],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1]]
|
||||||
|
"""
|
||||||
|
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
||||||
|
pos_idx = torch.arange(size, device=device)
|
||||||
|
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
||||||
|
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def add_optional_chunk_mask(xs: torch.Tensor,
|
def add_optional_chunk_mask(xs: torch.Tensor,
|
||||||
masks: torch.Tensor,
|
masks: torch.Tensor,
|
||||||
use_dynamic_chunk: bool,
|
use_dynamic_chunk: bool,
|
||||||
@@ -195,6 +229,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
|
|||||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||||
else:
|
else:
|
||||||
chunk_masks = masks
|
chunk_masks = masks
|
||||||
|
assert chunk_masks.dtype == torch.bool
|
||||||
|
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
||||||
|
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
||||||
|
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
||||||
return chunk_masks
|
return chunk_masks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
54
cosyvoice/utils/onnx.py
Normal file
54
cosyvoice/utils/onnx.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import onnxruntime
|
||||||
|
import torch, random
|
||||||
|
import os
|
||||||
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechTokenExtractor():
|
||||||
|
def __init__(self, model_path):
|
||||||
|
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
|
||||||
|
sess_options=option,
|
||||||
|
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
||||||
|
|
||||||
|
def inference(self, feat, feat_lengths, device):
|
||||||
|
speech_token = self.speech_tokenizer_session.run(None,
|
||||||
|
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||||
|
feat.transpose(1, 2).detach().cpu().numpy(),
|
||||||
|
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||||
|
feat_lengths.detach().cpu().numpy()})[0]
|
||||||
|
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingExtractor():
|
||||||
|
def __init__(self, model_path):
|
||||||
|
option = onnxruntime.SessionOptions()
|
||||||
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
option.intra_op_num_threads = 1
|
||||||
|
self.max_len = 10 * 16000
|
||||||
|
self.campplus_session = onnxruntime.InferenceSession(model_path,
|
||||||
|
sess_options=option,
|
||||||
|
providers=["CPUExecutionProvider"])
|
||||||
|
|
||||||
|
def inference(self, speech):
|
||||||
|
if speech.shape[1] > self.max_len:
|
||||||
|
start_index = random.randint(0, speech.shape[1] - self.max_len)
|
||||||
|
speech = speech[:, start_index: start_index + self.max_len]
|
||||||
|
feat = kaldi.fbank(speech,
|
||||||
|
num_mel_bins=80,
|
||||||
|
dither=0,
|
||||||
|
sample_frequency=16000)
|
||||||
|
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||||
|
embedding = self.campplus_session.run(None,
|
||||||
|
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||||
|
return torch.tensor(embedding).to(speech.device)
|
||||||
|
|
||||||
|
# singleton mode, only initialized once
|
||||||
|
onnx_path = os.environ.get('onnx_path')
|
||||||
|
if onnx_path is not None:
|
||||||
|
embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
|
||||||
|
else:
|
||||||
|
embedding_extractor, online_feature = None, False
|
||||||
@@ -567,8 +567,7 @@ class NoamAnnealing(_LRScheduler):
|
|||||||
min_lr=0.0,
|
min_lr=0.0,
|
||||||
last_epoch=-1):
|
last_epoch=-1):
|
||||||
self._normalize = d_model**(-0.5)
|
self._normalize = d_model**(-0.5)
|
||||||
assert not (warmup_steps is not None
|
assert not (warmup_steps is not None and warmup_ratio is not None), \
|
||||||
and warmup_ratio is not None), \
|
|
||||||
"Either use particular number of step or ratio"
|
"Either use particular number of step or ratio"
|
||||||
assert warmup_ratio is None or max_steps is not None, \
|
assert warmup_ratio is None or max_steps is not None, \
|
||||||
"If there is a ratio, there should be a total steps"
|
"If there is a ratio, there should be a total steps"
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from contextlib import nullcontext
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
@@ -51,9 +50,10 @@ def init_distributed(args):
|
|||||||
return world_size, local_rank, rank
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
def init_dataset_and_dataloader(args, configs):
|
def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||||
train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
|
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
||||||
cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
|
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
|
||||||
|
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||||
|
|
||||||
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
||||||
train_data_loader = DataLoader(train_dataset,
|
train_data_loader = DataLoader(train_dataset,
|
||||||
@@ -69,10 +69,9 @@ def init_dataset_and_dataloader(args, configs):
|
|||||||
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
|
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def check_modify_and_save_config(args, configs):
|
def check_modify_and_save_config(args, configs):
|
||||||
if args.train_engine == "torch_ddp":
|
if args.train_engine == "torch_ddp":
|
||||||
configs['train_conf']["dtype"] = 'fp32'
|
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
|
||||||
else:
|
else:
|
||||||
with open(args.deepspeed_config, 'r') as fin:
|
with open(args.deepspeed_config, 'r') as fin:
|
||||||
ds_configs = json.load(fin)
|
ds_configs = json.load(fin)
|
||||||
@@ -84,7 +83,8 @@ def check_modify_and_save_config(args, configs):
|
|||||||
configs['train_conf']["dtype"] = "fp32"
|
configs['train_conf']["dtype"] = "fp32"
|
||||||
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
|
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
|
||||||
# if use deepspeed, override ddp config
|
# if use deepspeed, override ddp config
|
||||||
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
|
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
|
||||||
|
configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
|
||||||
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
|
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
|
||||||
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
|
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
|
||||||
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
|
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
|
||||||
@@ -108,38 +108,80 @@ def wrap_cuda_model(args, model):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def init_optimizer_and_scheduler(args, configs, model):
|
def init_optimizer_and_scheduler(args, configs, model, gan):
|
||||||
if configs['train_conf']['optim'] == 'adam':
|
if gan is False:
|
||||||
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
|
if configs['train_conf']['optim'] == 'adam':
|
||||||
elif configs['train_conf']['optim'] == 'adamw':
|
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||||
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
|
elif configs['train_conf']['optim'] == 'adamw':
|
||||||
|
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||||
|
|
||||||
|
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||||
|
scheduler_type = WarmupLR
|
||||||
|
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||||
|
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||||
|
scheduler_type = NoamHoldAnnealing
|
||||||
|
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||||
|
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||||
|
scheduler_type = ConstantLR
|
||||||
|
scheduler = ConstantLR(optimizer)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||||
|
|
||||||
|
# use deepspeed optimizer for speedup
|
||||||
|
if args.train_engine == "deepspeed":
|
||||||
|
def scheduler(opt):
|
||||||
|
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
|
||||||
|
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||||
|
args=args,
|
||||||
|
model=model,
|
||||||
|
optimizer=None,
|
||||||
|
lr_scheduler=scheduler,
|
||||||
|
model_parameters=model.parameters())
|
||||||
|
|
||||||
|
optimizer_d, scheduler_d = None, None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
|
||||||
|
if configs['train_conf']['optim'] == 'adam':
|
||||||
|
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||||
|
elif configs['train_conf']['optim'] == 'adamw':
|
||||||
|
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||||
|
|
||||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||||
scheduler_type = WarmupLR
|
scheduler_type = WarmupLR
|
||||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||||
scheduler_type = NoamHoldAnnealing
|
scheduler_type = NoamHoldAnnealing
|
||||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||||
scheduler_type = ConstantLR
|
scheduler_type = ConstantLR
|
||||||
scheduler = ConstantLR(optimizer)
|
scheduler = ConstantLR(optimizer)
|
||||||
else:
|
else:
|
||||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||||
|
|
||||||
# use deepspeed optimizer for speedup
|
if configs['train_conf']['optim_d'] == 'adam':
|
||||||
if args.train_engine == "deepspeed":
|
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||||
def scheduler(opt):
|
elif configs['train_conf']['optim_d'] == 'adamw':
|
||||||
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
|
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
else:
|
||||||
args=args,
|
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||||
model=model,
|
|
||||||
optimizer=None,
|
|
||||||
lr_scheduler=scheduler,
|
|
||||||
model_parameters=model.parameters())
|
|
||||||
|
|
||||||
return model, optimizer, scheduler
|
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
||||||
|
scheduler_type = WarmupLR
|
||||||
|
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||||
|
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
||||||
|
scheduler_type = NoamHoldAnnealing
|
||||||
|
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||||
|
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||||
|
scheduler_type = ConstantLR
|
||||||
|
scheduler_d = ConstantLR(optimizer_d)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||||
|
return model, optimizer, scheduler, optimizer_d, scheduler_d
|
||||||
|
|
||||||
|
|
||||||
def init_summarywriter(args):
|
def init_summarywriter(args):
|
||||||
@@ -157,7 +199,7 @@ def save_model(model, model_name, info_dict):
|
|||||||
|
|
||||||
if info_dict["train_engine"] == "torch_ddp":
|
if info_dict["train_engine"] == "torch_ddp":
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
torch.save(model.module.state_dict(), save_model_path)
|
torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
|
||||||
else:
|
else:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.save_checkpoint(save_dir=model_dir,
|
model.save_checkpoint(save_dir=model_dir,
|
||||||
@@ -193,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def batch_forward(model, batch, info_dict):
|
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
|
||||||
device = int(os.environ.get('LOCAL_RANK', 0))
|
device = int(os.environ.get('LOCAL_RANK', 0))
|
||||||
|
|
||||||
dtype = info_dict["dtype"]
|
dtype = info_dict["dtype"]
|
||||||
@@ -205,36 +247,72 @@ def batch_forward(model, batch, info_dict):
|
|||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
if info_dict['train_engine'] == 'torch_ddp':
|
if info_dict['train_engine'] == 'torch_ddp':
|
||||||
autocast = nullcontext()
|
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
||||||
|
|
||||||
with autocast:
|
with autocast:
|
||||||
info_dict['loss_dict'] = model(batch, device)
|
info_dict['loss_dict'] = model(batch, device)
|
||||||
|
if ref_model is not None and dpo_loss is not None:
|
||||||
|
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
||||||
|
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|
||||||
|
sft_loss = info_dict['loss_dict']['loss']
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_loss_dict = ref_model(batch, device)
|
||||||
|
reference_chosen_logps = ref_loss_dict["chosen_logps"]
|
||||||
|
reference_rejected_logps = ref_loss_dict["rejected_logps"]
|
||||||
|
preference_loss, chosen_reward, reject_reward = dpo_loss(
|
||||||
|
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
|
||||||
|
)
|
||||||
|
dpo_acc = (chosen_reward > reject_reward).float().mean()
|
||||||
|
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
|
||||||
|
info_dict['loss_dict']["sft_loss"] = sft_loss
|
||||||
|
info_dict['loss_dict']["dpo_loss"] = preference_loss
|
||||||
|
info_dict['loss_dict']["dpo_acc"] = dpo_acc
|
||||||
|
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
|
||||||
|
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
|
|
||||||
def batch_backward(model, info_dict):
|
def batch_backward(model, scaler, info_dict):
|
||||||
if info_dict["train_engine"] == "deepspeed":
|
if info_dict["train_engine"] == "deepspeed":
|
||||||
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
|
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
|
||||||
else:
|
else:
|
||||||
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
|
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
|
||||||
scaled_loss.backward()
|
if scaler is not None:
|
||||||
|
scaler.scale(scaled_loss).backward()
|
||||||
|
else:
|
||||||
|
scaled_loss.backward()
|
||||||
|
|
||||||
info_dict['loss_dict']['loss'] = scaled_loss
|
info_dict['loss_dict']['loss'] = scaled_loss
|
||||||
return info_dict
|
return info_dict
|
||||||
|
|
||||||
|
|
||||||
def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
|
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
|
||||||
grad_norm = 0.0
|
grad_norm = 0.0
|
||||||
if info_dict['train_engine'] == "deepspeed":
|
if info_dict['train_engine'] == "deepspeed":
|
||||||
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
|
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
|
||||||
model.step()
|
model.step()
|
||||||
grad_norm = model.get_global_grad_norm()
|
grad_norm = model.get_global_grad_norm()
|
||||||
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
|
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
# Use mixed precision training
|
||||||
if torch.isfinite(grad_norm):
|
if scaler is not None:
|
||||||
optimizer.step()
|
scaler.unscale_(optimizer)
|
||||||
|
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||||
|
# We don't check grad here since that if the gradient
|
||||||
|
# has inf/nan values, scaler.step will skip
|
||||||
|
# optimizer.step().
|
||||||
|
if torch.isfinite(grad_norm):
|
||||||
|
scaler.step(optimizer)
|
||||||
|
else:
|
||||||
|
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||||
|
if torch.isfinite(grad_norm):
|
||||||
|
optimizer.step()
|
||||||
|
else:
|
||||||
|
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
info_dict["lr"] = optimizer.param_groups[0]['lr']
|
info_dict["lr"] = optimizer.param_groups[0]['lr']
|
||||||
@@ -280,7 +358,7 @@ def log_per_save(writer, info_dict):
|
|||||||
rank = int(os.environ.get('RANK', 0))
|
rank = int(os.environ.get('RANK', 0))
|
||||||
logging.info(
|
logging.info(
|
||||||
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
|
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
|
||||||
epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
|
epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
for k in ['epoch', 'lr']:
|
for k in ['epoch', 'lr']:
|
||||||
|
|||||||
116
cosyvoice/vllm/cosyvoice2.py
Normal file
116
cosyvoice/vllm/cosyvoice2.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||||
|
# Copyright 2024 The Qwen team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||||
|
from typing import Optional
|
||||||
|
from packaging.version import parse as vparse
|
||||||
|
import vllm
|
||||||
|
|
||||||
|
# vLLM-0.11.0+ only support V1 engine
|
||||||
|
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
|
||||||
|
if VLLM_V1_ENGINE_ONLY:
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
from vllm.model_executor.models.qwen2 import *
|
||||||
|
|
||||||
|
|
||||||
|
class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "lm_head"))
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
|
inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: Optional[SamplingMetadata] = None,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
if VLLM_V1_ENGINE_ONLY:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
self.lm_head.bias)
|
||||||
|
else:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata, self.lm_head.bias)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=(["lm_head."]
|
||||||
|
if self.config.tie_word_embeddings else None),
|
||||||
|
)
|
||||||
|
return loader.load_weights(weights)
|
||||||
51
docker/Dockerfile
Normal file
51
docker/Dockerfile
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
||||||
|
|
||||||
|
ARG VENV_NAME="cosyvoice"
|
||||||
|
ENV VENV=$VENV_NAME
|
||||||
|
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
SHELL ["/bin/bash", "--login", "-c"]
|
||||||
|
|
||||||
|
RUN apt-get update -y --fix-missing
|
||||||
|
RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \
|
||||||
|
apt-get clean && \
|
||||||
|
git lfs install
|
||||||
|
|
||||||
|
# ==================================================================
|
||||||
|
# conda install and conda forge channel as default
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Install miniforge
|
||||||
|
RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \
|
||||||
|
/bin/bash ~/miniforge.sh -b -p /opt/conda && \
|
||||||
|
rm ~/miniforge.sh && \
|
||||||
|
ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
|
||||||
|
echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \
|
||||||
|
echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
||||||
|
echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \
|
||||||
|
echo "conda activate ${VENV}" >> $HOME/.bashrc
|
||||||
|
|
||||||
|
ENV PATH /opt/conda/bin:$PATH
|
||||||
|
|
||||||
|
RUN conda config --add channels conda-forge && \
|
||||||
|
conda config --set channel_priority strict
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# ~conda
|
||||||
|
# ==================================================================
|
||||||
|
|
||||||
|
RUN conda create -y -n ${VENV} python=3.10
|
||||||
|
ENV CONDA_DEFAULT_ENV=${VENV}
|
||||||
|
ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS"
|
||||||
|
|
||||||
|
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||||
|
|
||||||
|
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
||||||
|
RUN conda activate ${VENV} && cd CosyVoice && \
|
||||||
|
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||||
|
|
||||||
|
WORKDIR /workspace/CosyVoice
|
||||||
112
example.py
Normal file
112
example.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append('third_party/Matcha-TTS')
|
||||||
|
from cosyvoice.cli.cosyvoice import AutoModel
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
def cosyvoice_example():
|
||||||
|
""" CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
|
||||||
|
"""
|
||||||
|
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||||
|
# sft usage
|
||||||
|
print(cosyvoice.list_available_spks())
|
||||||
|
# change stream=True for chunk stream inference
|
||||||
|
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||||
|
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
|
||||||
|
# zero_shot usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||||
|
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
# cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||||
|
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
|
||||||
|
'./asset/cross_lingual_prompt.wav')):
|
||||||
|
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
# vc usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
|
||||||
|
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||||
|
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||||
|
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
|
||||||
|
'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
|
||||||
|
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def cosyvoice2_example():
|
||||||
|
""" CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
|
||||||
|
"""
|
||||||
|
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
|
||||||
|
|
||||||
|
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||||
|
# zero_shot usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||||
|
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# save zero_shot spk for future usage
|
||||||
|
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
|
||||||
|
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
cosyvoice.save_spkinfo()
|
||||||
|
|
||||||
|
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
||||||
|
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
|
||||||
|
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# instruct usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
|
||||||
|
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# bistream usage, you can use generator as input, this is useful when using text llm model as input
|
||||||
|
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
|
||||||
|
def text_generator():
|
||||||
|
yield '收到好友从远方寄来的生日礼物,'
|
||||||
|
yield '那份意外的惊喜与深深的祝福'
|
||||||
|
yield '让我心中充满了甜蜜的快乐,'
|
||||||
|
yield '笑容如花儿般绽放。'
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def cosyvoice3_example():
|
||||||
|
""" CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
|
||||||
|
"""
|
||||||
|
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||||
|
# zero_shot usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
|
||||||
|
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# instruct usage, for supported control, check cosyvoice/utils/common.py#L28
|
||||||
|
for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# hotfix usage
|
||||||
|
for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
# NOTE for Japanese usage, you must translate it to katakana.
|
||||||
|
# 歴史的世界においては、過去は単に過ぎ去ったものではない、プラトンのいう如く非有が有である。 -> レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。
|
||||||
|
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。',
|
||||||
|
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||||
|
torchaudio.save('japanese_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# cosyvoice_example()
|
||||||
|
# cosyvoice2_example()
|
||||||
|
cosyvoice3_example()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
6
examples/grpo/cosyvoice2/Dockerfile
Normal file
6
examples/grpo/cosyvoice2/Dockerfile
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||||
|
COPY requirements.txt /myworkspace/requirements.txt
|
||||||
|
RUN pip install -r /myworkspace/requirements.txt
|
||||||
|
RUN pip install -U nvidia-pytriton
|
||||||
|
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
|
||||||
|
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .
|
||||||
125
examples/grpo/cosyvoice2/README.md
Normal file
125
examples/grpo/cosyvoice2/README.md
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# CosyVoice2 LLM Reinforcement Learning Recipe
|
||||||
|
|
||||||
|
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Environment Setup](#environment-setup)
|
||||||
|
- [Data Preparation](#data-preparation)
|
||||||
|
- [Reward Function & ASR Server](#reward-function--asr-server)
|
||||||
|
- [Training](#training)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Export Model](#export-model)
|
||||||
|
- [Results](#results)
|
||||||
|
- [Acknowledgement](#acknowledgement)
|
||||||
|
|
||||||
|
## Environment Setup
|
||||||
|
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
|
||||||
|
```bash
|
||||||
|
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||||
|
```
|
||||||
|
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
|
||||||
|
|
||||||
|
## Data Preparation
|
||||||
|
|
||||||
|
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"text": "An example sentence to be synthesized."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
|
||||||
|
|
||||||
|
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 0 0
|
||||||
|
```
|
||||||
|
Create two JSONL files—`train.jsonl` and `test.jsonl`.
|
||||||
|
The script will then generate two Parquet files:
|
||||||
|
|
||||||
|
```
|
||||||
|
data/parquet_tiny/train.parquet
|
||||||
|
data/parquet_tiny/test.parquet
|
||||||
|
```
|
||||||
|
|
||||||
|
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
|
||||||
|
|
||||||
|
|
||||||
|
## Reward Function & ASR Server
|
||||||
|
|
||||||
|
To compute rewards, we run a lightweight server that:
|
||||||
|
|
||||||
|
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
|
||||||
|
2. Transcribes the waveform with **SenseVoice** ASR.
|
||||||
|
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
|
||||||
|
|
||||||
|
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 1 1
|
||||||
|
# Triton server listens on ports 8000/8001/8002
|
||||||
|
```
|
||||||
|
|
||||||
|
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Run stage `2` to start GRPO training:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 2 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Key CLI arguments passed to `verl.trainer.main_ppo`:
|
||||||
|
|
||||||
|
* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
|
||||||
|
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
|
||||||
|
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
|
||||||
|
|
||||||
|
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
|
||||||
|
> [!TIP]
|
||||||
|
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 4 4
|
||||||
|
```
|
||||||
|
|
||||||
|
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
|
||||||
|
|
||||||
|
## Export Model
|
||||||
|
|
||||||
|
To use the RL-trained model with the official CosyVoice repository:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash run.sh 5 5
|
||||||
|
```
|
||||||
|
|
||||||
|
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
||||||
|
> [!TIP]
|
||||||
|
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|
||||||
|
|-------|------------------------|------------------------------|---------|
|
||||||
|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
|
||||||
|
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
|
||||||
|
|
||||||
|
## Acknowledgement
|
||||||
|
|
||||||
|
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).
|
||||||
71
examples/grpo/cosyvoice2/huggingface_to_pretrained.py
Normal file
71
examples/grpo/cosyvoice2/huggingface_to_pretrained.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||||
|
"""
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import torch
|
||||||
|
from safetensors import safe_open
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-cosyvoice2-llm-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The RL trained CosyVoice2 model path in HuggingFace format",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-path",
|
||||||
|
type=str,
|
||||||
|
default="./llm.pt",
|
||||||
|
help="The path to save the llm.pt",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
|
||||||
|
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
|
||||||
|
cosyvoice2_token_size = 6561 + 3
|
||||||
|
llm_embedding_vocab_size = 2
|
||||||
|
|
||||||
|
hf_tensors = {}
|
||||||
|
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if k.startswith("lm_head.bias"):
|
||||||
|
# RL trained model disable bias for lm_head
|
||||||
|
continue
|
||||||
|
new_k = "llm.model." + k
|
||||||
|
hf_tensors[new_k] = f.get_tensor(k)
|
||||||
|
if k.startswith("lm_head"):
|
||||||
|
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||||
|
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
|
||||||
|
if k.startswith("model.embed_tokens"):
|
||||||
|
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||||
|
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
|
||||||
|
|
||||||
|
# use tie_word_embeddings=True
|
||||||
|
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
|
||||||
|
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
|
||||||
|
|
||||||
|
torch.save(hf_tensors, args.output_path)
|
||||||
397
examples/grpo/cosyvoice2/infer_dataset.py
Normal file
397
examples/grpo/cosyvoice2/infer_dataset.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Example Usage
|
||||||
|
dataset=zero_shot_zh
|
||||||
|
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
|
||||||
|
|
||||||
|
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
|
torchrun --nproc_per_node=8 \
|
||||||
|
infer_dataset.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $llm_path/merged_hf_model \
|
||||||
|
--token2wav-path $token2wav_path \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
from cosyvoice.utils.file_utils import load_wav
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
import soundfile as sf
|
||||||
|
import s3tokenizer
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
try:
|
||||||
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def audio_decode_cosyvoice2(
|
||||||
|
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio from tokens with optional tone and prompt embedding.
|
||||||
|
"""
|
||||||
|
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||||
|
"empty", prompt_text, prompt_speech_16k, 24000
|
||||||
|
)
|
||||||
|
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||||
|
token=audio_tokens.to(codec_decoder.model.device),
|
||||||
|
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token_len=torch.tensor(
|
||||||
|
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||||
|
finalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||||
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_hat
|
||||||
|
|
||||||
|
|
||||||
|
def extract_speech_ids(speech_tokens_str):
|
||||||
|
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||||
|
speech_ids = []
|
||||||
|
for token_str in speech_tokens_str:
|
||||||
|
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||||
|
num_str = token_str[4:-2]
|
||||||
|
num = int(num_str)
|
||||||
|
speech_ids.append(num)
|
||||||
|
else:
|
||||||
|
print(f"Unexpected token: {token_str}")
|
||||||
|
return speech_ids
|
||||||
|
|
||||||
|
|
||||||
|
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||||
|
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||||
|
speech_id_str = ""
|
||||||
|
for token in cosy2_tokens:
|
||||||
|
speech_id_str += f"<|s_{token}|>"
|
||||||
|
return speech_id_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir", required=True, type=str, help="dir to save result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="batch size (per-device) for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers", type=int, default=1, help="workers for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model-name-or-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="LLM model path (includes both model and tokenizer)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token2wav-path",
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help="CosyVoice2 token2wav model path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-text",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The prompt text for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-speech-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path to the prompt speech for CosyVoice2",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.95,
|
||||||
|
help="top p for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="temperature for sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="top k for sampling",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||||
|
"""Simplified data collator for batch_size=1 processing"""
|
||||||
|
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||||
|
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||||
|
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||||
|
mels, prompt_audio_cosy2tokens_list = [], []
|
||||||
|
for item in batch:
|
||||||
|
prompt_text, target_text = (
|
||||||
|
item["prompt_text"],
|
||||||
|
item["target_text"],
|
||||||
|
)
|
||||||
|
prompt_text_list.append(prompt_text)
|
||||||
|
# Combine prompt and target text
|
||||||
|
full_text = prompt_text + target_text
|
||||||
|
|
||||||
|
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||||
|
ref_audio_org, ref_sr = (
|
||||||
|
item["prompt_audio"]["array"],
|
||||||
|
item["prompt_audio"]["sampling_rate"],
|
||||||
|
)
|
||||||
|
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||||
|
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
|
||||||
|
print(ref_audio_org.shape)
|
||||||
|
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
|
||||||
|
prompt_audio_list.append(ref_audio)
|
||||||
|
|
||||||
|
if "prompt_audio_cosy2_tokens" in item:
|
||||||
|
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||||
|
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||||
|
else:
|
||||||
|
# convert to float first
|
||||||
|
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||||
|
|
||||||
|
if len(mels) > 0:
|
||||||
|
mels, mels_lens = s3tokenizer.padding(mels)
|
||||||
|
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||||
|
for i in range(len(codes)):
|
||||||
|
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||||
|
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
|
||||||
|
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||||
|
# Create chat template for LLM generation
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": full_text},
|
||||||
|
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||||
|
]
|
||||||
|
if 'system' in tokenizer.chat_template:
|
||||||
|
tokenizer.chat_template = TEMPLATE
|
||||||
|
input_ids = tokenizer.apply_chat_template(
|
||||||
|
chat,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors='pt',
|
||||||
|
continue_final_message=True
|
||||||
|
)
|
||||||
|
input_ids_list.append(input_ids.squeeze(0))
|
||||||
|
|
||||||
|
# For batch_size=1, no need to pad
|
||||||
|
if len(input_ids_list) == 1:
|
||||||
|
input_ids = input_ids_list[0].unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# Handle batch > 1 if needed
|
||||||
|
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||||
|
input_ids_list = [
|
||||||
|
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
|
||||||
|
for input_ids in input_ids_list
|
||||||
|
]
|
||||||
|
input_ids = torch.stack(input_ids_list)
|
||||||
|
|
||||||
|
ids = [item["id"] for item in batch]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"ids": ids,
|
||||||
|
"prompt_text": prompt_text_list,
|
||||||
|
"prompt_audio_list": prompt_audio_list,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed():
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
rank = int(os.environ.get("RANK", 0))
|
||||||
|
print(
|
||||||
|
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||||
|
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(local_rank)
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
return world_size, local_rank, rank
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
assert torch.cuda.is_available()
|
||||||
|
world_size, local_rank, rank = init_distributed()
|
||||||
|
device = torch.device(f"cuda:{local_rank}")
|
||||||
|
|
||||||
|
# Load LLM model and tokenizer directly
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
cosyvoice_codec = CosyVoice2(
|
||||||
|
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
|
||||||
|
)
|
||||||
|
if args.prompt_speech_path:
|
||||||
|
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||||
|
else:
|
||||||
|
prompt_speech_16k = None
|
||||||
|
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
|
||||||
|
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch,
|
||||||
|
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
total_steps = len(dataset)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
input_ids = batch["input_ids"].to(device)
|
||||||
|
|
||||||
|
# Generate speech tokens using LLM
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
max_new_tokens=2048, # Max length for generation
|
||||||
|
do_sample=True,
|
||||||
|
top_p=args.top_p,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process each sample in the batch
|
||||||
|
for i in range(len(batch["ids"])):
|
||||||
|
# Extract generated tokens (excluding input)
|
||||||
|
input_length = input_ids[i].shape[0]
|
||||||
|
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
|
||||||
|
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Extract speech IDs from token strings like <|s_23456|>
|
||||||
|
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||||
|
|
||||||
|
if len(speech_ids) == 0:
|
||||||
|
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert to tensor for CosyVoice2
|
||||||
|
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
|
||||||
|
|
||||||
|
if args.prompt_text is not None:
|
||||||
|
current_prompt_text = args.prompt_text
|
||||||
|
current_prompt_audio = prompt_speech_16k
|
||||||
|
else:
|
||||||
|
current_prompt_text = batch["prompt_text"][i]
|
||||||
|
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||||
|
|
||||||
|
if current_prompt_audio is not None:
|
||||||
|
# Generate audio using CosyVoice2
|
||||||
|
audio_hat = audio_decode_cosyvoice2(
|
||||||
|
audio_tokens,
|
||||||
|
current_prompt_text,
|
||||||
|
current_prompt_audio,
|
||||||
|
cosyvoice_codec,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to numpy and save
|
||||||
|
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||||
|
target_sample_rate = 24000
|
||||||
|
|
||||||
|
utt = batch["ids"][i]
|
||||||
|
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||||
|
|
||||||
|
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
|
||||||
|
else:
|
||||||
|
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.update(world_size * len(batch["ids"]))
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
86
examples/grpo/cosyvoice2/prepare_data.py
Normal file
86
examples/grpo/cosyvoice2/prepare_data.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Preprocess the Text to Speech dataset to parquet format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
from verl.utils.hdfs_io import copy, makedirs
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
|
||||||
|
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
|
||||||
|
parser.add_argument("--local_dir", default=None, required=True)
|
||||||
|
parser.add_argument("--hdfs_dir", default=None)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load datasets from local JSON files
|
||||||
|
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
|
||||||
|
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
|
||||||
|
|
||||||
|
# add a row to each data item that represents a unique id
|
||||||
|
def make_map_fn(split):
|
||||||
|
def process_fn(example, idx):
|
||||||
|
text = example.pop("text")
|
||||||
|
|
||||||
|
# use cosyvoice2 official huggingface compatible checkpoint template
|
||||||
|
question = text
|
||||||
|
answer = ""
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
|
||||||
|
"prompt": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": question,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": answer,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"ability": "text-to-speech",
|
||||||
|
"reward_model": {"style": "rule", "ground_truth": text},
|
||||||
|
"extra_info": {
|
||||||
|
"split": split,
|
||||||
|
"index": idx,
|
||||||
|
"text": text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
return process_fn
|
||||||
|
|
||||||
|
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
|
||||||
|
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
|
||||||
|
|
||||||
|
local_dir = args.local_dir
|
||||||
|
hdfs_dir = args.hdfs_dir
|
||||||
|
|
||||||
|
print(train_dataset)
|
||||||
|
print(test_dataset)
|
||||||
|
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
|
||||||
|
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
|
||||||
|
|
||||||
|
if hdfs_dir is not None:
|
||||||
|
makedirs(hdfs_dir)
|
||||||
|
|
||||||
|
copy(src=local_dir, dst=hdfs_dir)
|
||||||
133
examples/grpo/cosyvoice2/pretrained_to_huggingface.py
Normal file
133
examples/grpo/cosyvoice2/pretrained_to_huggingface.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage: Instruct TTS
|
||||||
|
python3 infer.py \
|
||||||
|
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||||
|
--prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||||
|
--prompt-speech-path ./assets/prompt_audio.wav \
|
||||||
|
--model-path ./transformers_cosyvoice2_llm \
|
||||||
|
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
|
||||||
|
"""
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
import sys
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained-cosyvoice2-path",
|
||||||
|
type=str,
|
||||||
|
default="/workspace/CosyVoice2-0.5B",
|
||||||
|
help="Token2Wav path, default to %(default)r",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default='./transformers_cosyvoice2_llm',
|
||||||
|
help="The path to save the model",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
cosy2_model = CosyVoice2(
|
||||||
|
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = cosy2_model.model.llm.llm.model
|
||||||
|
|
||||||
|
speech_embedding = cosy2_model.model.llm.speech_embedding
|
||||||
|
llm_decoder = cosy2_model.model.llm.llm_decoder
|
||||||
|
llm_embedding = cosy2_model.model.llm.llm_embedding
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
|
||||||
|
special_tokens = {
|
||||||
|
'eos_token': '<|endoftext|>',
|
||||||
|
'pad_token': '<|endoftext|>',
|
||||||
|
'additional_special_tokens': [
|
||||||
|
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||||
|
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||||
|
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||||
|
'[quick_breath]',
|
||||||
|
"<laughter>", "</laughter>",
|
||||||
|
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||||
|
"[lipsmack]", "[mn]"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tokenizer.add_special_tokens(special_tokens)
|
||||||
|
|
||||||
|
original_tokenizer_vocab_size = len(tokenizer)
|
||||||
|
cosyvoice2_token_size = 6561
|
||||||
|
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||||
|
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
|
||||||
|
]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||||
|
|
||||||
|
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
|
||||||
|
vocab_size = llm.get_input_embeddings().weight.shape[0]
|
||||||
|
|
||||||
|
feature_size = speech_embedding.embedding_dim
|
||||||
|
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# set the weight and bias of the new lm_head to 0
|
||||||
|
new_lm_head.weight.data.zero_()
|
||||||
|
# make bias value -inf
|
||||||
|
new_lm_head.bias.data.fill_(-float('inf'))
|
||||||
|
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
|
||||||
|
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
|
||||||
|
|
||||||
|
llm.lm_head = new_lm_head
|
||||||
|
input_embeddings = llm.get_input_embeddings()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
|
||||||
|
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
|
||||||
|
|
||||||
|
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
|
||||||
|
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
|
||||||
|
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
|
||||||
|
llm.generation_config.eos_token_id = eos_token_ids
|
||||||
|
llm.generation_config.temperature = 1.0
|
||||||
|
llm.generation_config.top_p = 0.8
|
||||||
|
llm.generation_config.top_k = 25
|
||||||
|
|
||||||
|
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
|
||||||
|
llm.config.vocab_size = vocab_size
|
||||||
|
llm.config.tie_word_embeddings = False
|
||||||
|
llm.config.use_bias = True
|
||||||
|
llm.to(torch.bfloat16)
|
||||||
|
llm.save_pretrained(args.save_path)
|
||||||
|
|
||||||
|
TEMPLATE = (
|
||||||
|
"{%- for message in messages %}"
|
||||||
|
"{%- if message['role'] == 'user' %}"
|
||||||
|
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
|
||||||
|
"{%- elif message['role'] == 'assistant' %}"
|
||||||
|
"{{- message['content']}}"
|
||||||
|
"{%- endif %}"
|
||||||
|
"{%- endfor %}"
|
||||||
|
)
|
||||||
|
tokenizer.chat_template = TEMPLATE
|
||||||
|
tokenizer.save_pretrained(args.save_path)
|
||||||
31
examples/grpo/cosyvoice2/requirements.txt
Normal file
31
examples/grpo/cosyvoice2/requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
conformer==0.3.2
|
||||||
|
diffusers==0.29.0
|
||||||
|
gdown==5.1.0
|
||||||
|
gradio
|
||||||
|
hydra-core==1.3.2
|
||||||
|
HyperPyYAML==1.2.2
|
||||||
|
inflect==7.3.1
|
||||||
|
librosa==0.10.2
|
||||||
|
lightning==2.2.4
|
||||||
|
matplotlib==3.7.5
|
||||||
|
modelscope==1.15.0
|
||||||
|
networkx==3.1
|
||||||
|
omegaconf==2.3.0
|
||||||
|
onnx==1.16.0
|
||||||
|
onnxruntime-gpu==1.18.0
|
||||||
|
protobuf==4.25
|
||||||
|
pydantic==2.7.0
|
||||||
|
pyworld==0.3.4
|
||||||
|
rich==13.7.1
|
||||||
|
soundfile==0.12.1
|
||||||
|
tensorboard==2.14.0
|
||||||
|
wget==3.2
|
||||||
|
WeTextProcessing==1.0.3
|
||||||
|
s3tokenizer
|
||||||
|
tensorrt
|
||||||
|
sherpa_onnx
|
||||||
|
jiwer
|
||||||
|
zhon
|
||||||
|
numpy==1.25.2
|
||||||
|
pypinyin
|
||||||
|
openai-whisper
|
||||||
233
examples/grpo/cosyvoice2/reward_tts.py
Normal file
233
examples/grpo/cosyvoice2/reward_tts.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Reward calculation for CosyVoice2-0.5B.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ids(token_str: str) -> List[int]:
|
||||||
|
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
|
||||||
|
|
||||||
|
|
||||||
|
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
|
||||||
|
"""Send token IDs and ground-truth text to the Triton server and get reward."""
|
||||||
|
|
||||||
|
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
|
||||||
|
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
|
||||||
|
|
||||||
|
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs": [
|
||||||
|
{
|
||||||
|
"name": "TOKENS",
|
||||||
|
"shape": list(tokens_arr.shape),
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": tokens_arr.tolist(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "TOKEN_LENS",
|
||||||
|
"shape": list(lens_arr.shape),
|
||||||
|
"datatype": "INT32",
|
||||||
|
"data": lens_arr.tolist(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "GT_TEXT",
|
||||||
|
"shape": [1, 1],
|
||||||
|
"datatype": "BYTES",
|
||||||
|
"data": [ground_truth],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
rsp = requests.post(
|
||||||
|
REWARD_SERVER_URL,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
json=payload,
|
||||||
|
timeout=timeout,
|
||||||
|
verify=False,
|
||||||
|
params={"request_id": "0"},
|
||||||
|
)
|
||||||
|
rsp.raise_for_status()
|
||||||
|
result = rsp.json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Reward is returned as the first output
|
||||||
|
return float(result["outputs"][0]["data"][0])
|
||||||
|
except (KeyError, IndexError, TypeError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_score(
|
||||||
|
data_source: str,
|
||||||
|
solution_str: str,
|
||||||
|
ground_truth: str,
|
||||||
|
extra_info: dict | None = None,
|
||||||
|
*,
|
||||||
|
debug_dump: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Return reward in [0, 1] using the Triton ASR service.
|
||||||
|
|
||||||
|
The reward is based on the pinyin-level WER between the ASR transcript
|
||||||
|
produced from *solution_str* and the provided *ground_truth* text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Decode token IDs
|
||||||
|
ids = _parse_ids(solution_str)
|
||||||
|
|
||||||
|
# Query remote server for reward
|
||||||
|
try:
|
||||||
|
reward = _remote_reward(ids, ground_truth)
|
||||||
|
except Exception as e:
|
||||||
|
reward = 0.0
|
||||||
|
|
||||||
|
if debug_dump:
|
||||||
|
print(
|
||||||
|
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
|
||||||
|
)
|
||||||
|
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
# CLI quick test
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Test TTS CER scoring with data from JSONL file",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input", "-i",
|
||||||
|
type=str,
|
||||||
|
default="data/emilia_zh-cosy-tiny-test.jsonl",
|
||||||
|
help="Path to input JSONL file"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-samples", "-n",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of samples to process (default: all)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-interactive",
|
||||||
|
action="store_true",
|
||||||
|
help="Run in non-interactive mode (process all samples without prompts)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable debug mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def load_jsonl(file_path: str):
|
||||||
|
"""Load data from jsonl file."""
|
||||||
|
data = []
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
data.append(json.loads(line.strip()))
|
||||||
|
return data
|
||||||
|
|
||||||
|
def code_to_solution_str(code_list: List[int]) -> str:
|
||||||
|
"""Convert code list to solution string format."""
|
||||||
|
return ''.join([f"<|s_{code}|>" for code in code_list])
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load data from jsonl file
|
||||||
|
print(f"Loading data from: {args.input}")
|
||||||
|
data_list = load_jsonl(args.input)
|
||||||
|
print(f"Loaded {len(data_list)} samples")
|
||||||
|
|
||||||
|
# Limit samples if specified
|
||||||
|
if args.max_samples is not None:
|
||||||
|
data_list = data_list[:args.max_samples]
|
||||||
|
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
|
||||||
|
|
||||||
|
# Process each sample
|
||||||
|
begin_time = time.time()
|
||||||
|
for i, sample in enumerate(data_list):
|
||||||
|
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
|
||||||
|
print(f"Index: {sample.get('index', 'unknown')}")
|
||||||
|
print(f"Text: {sample['text']}")
|
||||||
|
|
||||||
|
# Extract required fields
|
||||||
|
code_list = sample['code']
|
||||||
|
ground_truth = sample['text']
|
||||||
|
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
|
||||||
|
|
||||||
|
# Convert code list to solution string
|
||||||
|
solution_str = code_to_solution_str(code_list)
|
||||||
|
print(f"Solution tokens: {len(code_list)} tokens")
|
||||||
|
if args.debug:
|
||||||
|
print(f"Solution string: {solution_str}")
|
||||||
|
else:
|
||||||
|
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
|
||||||
|
|
||||||
|
# Call compute_score function
|
||||||
|
try:
|
||||||
|
score = compute_score(
|
||||||
|
data_source=data_source,
|
||||||
|
solution_str=solution_str,
|
||||||
|
ground_truth=ground_truth,
|
||||||
|
extra_info=None,
|
||||||
|
debug_dump=args.debug
|
||||||
|
)
|
||||||
|
print(f"Final Score: {score:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error computing score: {e}")
|
||||||
|
|
||||||
|
# Ask user if they want to continue (for interactive mode)
|
||||||
|
if not args.no_interactive and i < len(data_list) - 1:
|
||||||
|
try:
|
||||||
|
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
|
||||||
|
if response == 'q':
|
||||||
|
break
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopped by user")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"\nProcessed {min(i+1, len(data_list))} samples")
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time taken: {end_time - begin_time} seconds")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: File not found - {args.input}")
|
||||||
|
print("Please check the file path or use --input to specify correct path")
|
||||||
|
print("Run with --help for usage information")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
159
examples/grpo/cosyvoice2/run.sh
Normal file
159
examples/grpo/cosyvoice2/run.sh
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -eou pipefail
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=4
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
export PYTHONPATH=/workspace/CosyVoice
|
||||||
|
model_scope_model_path=./CosyVoice2-0.5B
|
||||||
|
sft_model_path=./transformers_cosyvoice2_llm
|
||||||
|
|
||||||
|
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
|
||||||
|
log "stage -2: install dependencies locally if pre-built docker image is not available"
|
||||||
|
conda create -n cosyvoice2 python=3.10 -y
|
||||||
|
conda activate cosyvoice2
|
||||||
|
# install verl
|
||||||
|
git clone https://github.com/yuekaizhang/verl.git -b thread
|
||||||
|
cd verl
|
||||||
|
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
|
||||||
|
pip install --no-deps -e .
|
||||||
|
cd -
|
||||||
|
# install requirements
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install -U nvidia-pytriton
|
||||||
|
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||||
|
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
||||||
|
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
||||||
|
python3 pretrained_to_huggingface.py \
|
||||||
|
--pretrained-cosyvoice2-path $model_scope_model_path \
|
||||||
|
--save-path $sft_model_path
|
||||||
|
|
||||||
|
# Or, you could use the following command to download the huggingface compatible checkpoint
|
||||||
|
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
|
||||||
|
|
||||||
|
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
|
||||||
|
fi
|
||||||
|
|
||||||
|
data_dir=data/parquet_aishell3
|
||||||
|
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||||
|
log "stage 0: prepare data into verl format"
|
||||||
|
mkdir -p $data_dir
|
||||||
|
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
|
||||||
|
# total 88035 samples
|
||||||
|
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
|
||||||
|
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
|
||||||
|
python prepare_data.py \
|
||||||
|
--train_file data/train.jsonl \
|
||||||
|
--test_file data/test.jsonl \
|
||||||
|
--local_dir $data_dir
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
|
log "stage 1: start token2wav asr server for reward function"
|
||||||
|
python3 token2wav_asr_server.py --number-of-devices 8
|
||||||
|
fi
|
||||||
|
|
||||||
|
exp_name=official_llm_aishell3_grpo
|
||||||
|
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||||
|
log "stage 2: grpo train"
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
export MKL_SERVICE_FORCE_INTEL=TRUE
|
||||||
|
n_gpus_per_node=8
|
||||||
|
micro_batch_size=4
|
||||||
|
train_batch_size=32
|
||||||
|
python3 -m verl.trainer.main_ppo \
|
||||||
|
algorithm.adv_estimator=grpo \
|
||||||
|
data.train_files=$data_dir/train.parquet \
|
||||||
|
data.val_files=$data_dir/test.parquet \
|
||||||
|
data.train_batch_size=$train_batch_size \
|
||||||
|
data.max_prompt_length=1024 \
|
||||||
|
data.max_response_length=512 \
|
||||||
|
data.truncation='error' \
|
||||||
|
actor_rollout_ref.model.use_remove_padding=False \
|
||||||
|
actor_rollout_ref.model.path=$sft_model_path \
|
||||||
|
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||||
|
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
|
||||||
|
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
|
||||||
|
actor_rollout_ref.actor.use_kl_loss=False \
|
||||||
|
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||||
|
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||||
|
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
|
||||||
|
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||||
|
actor_rollout_ref.rollout.name=vllm \
|
||||||
|
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||||
|
actor_rollout_ref.rollout.do_sample=true \
|
||||||
|
actor_rollout_ref.rollout.temperature=0.8 \
|
||||||
|
actor_rollout_ref.rollout.top_p=0.95 \
|
||||||
|
actor_rollout_ref.rollout.top_k=25 \
|
||||||
|
actor_rollout_ref.rollout.n=4 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
|
||||||
|
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
|
||||||
|
reward_model.reward_manager=prime \
|
||||||
|
custom_reward_function.path=reward_tts.py \
|
||||||
|
custom_reward_function.name=compute_score \
|
||||||
|
trainer.project_name='cosyvoice2_grpo' \
|
||||||
|
trainer.experiment_name=$exp_name \
|
||||||
|
trainer.logger=['console','wandb'] \
|
||||||
|
trainer.n_gpus_per_node=$n_gpus_per_node \
|
||||||
|
trainer.nnodes=1 \
|
||||||
|
trainer.save_freq=100 \
|
||||||
|
trainer.test_freq=100 \
|
||||||
|
trainer.resume_mode='auto' \
|
||||||
|
trainer.total_epochs=1 \
|
||||||
|
trainer.val_before_train=False
|
||||||
|
fi
|
||||||
|
|
||||||
|
steps=(100 200 300 400 500)
|
||||||
|
for step in ${steps[@]}; do
|
||||||
|
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
|
||||||
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
|
log "stage 3: merge the model"
|
||||||
|
python -m verl.model_merger merge \
|
||||||
|
--backend fsdp \
|
||||||
|
--local_dir $llm_path/actor \
|
||||||
|
--target_dir $llm_path/merged_hf_model || exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
log "stage 4: Test the model"
|
||||||
|
dataset=zero_shot_zh # from CosyVoice3 test set
|
||||||
|
# dataset=test_zh # from seed_tts test set
|
||||||
|
output_dir=./outputs_${exp_name}_${step}_${dataset}
|
||||||
|
|
||||||
|
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||||
|
model_path=$llm_path/merged_hf_model
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
|
torchrun --nproc_per_node=8 \
|
||||||
|
infer_dataset.py \
|
||||||
|
--output-dir $output_dir \
|
||||||
|
--llm-model-name-or-path $model_path \
|
||||||
|
--token2wav-path $token2wav_path \
|
||||||
|
--split-name ${dataset} || exit 1
|
||||||
|
|
||||||
|
bash scripts/compute_wer.sh $output_dir ${dataset}
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
log "stage 5: Convert the RL trained model to CosyVoice repo format"
|
||||||
|
python3 huggingface_to_pretrained.py \
|
||||||
|
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
|
||||||
|
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||||
|
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
|
||||||
|
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
|
||||||
|
# Please be careful or use the huggingface format inference code.
|
||||||
|
fi
|
||||||
33
examples/grpo/cosyvoice2/scripts/compute_wer.sh
Normal file
33
examples/grpo/cosyvoice2/scripts/compute_wer.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
wav_dir=$1
|
||||||
|
wav_files=$(ls $wav_dir/*.wav)
|
||||||
|
# if wav_files is empty, then exit
|
||||||
|
if [ -z "$wav_files" ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
split_name=$2
|
||||||
|
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
|
||||||
|
|
||||||
|
if [ ! -d $model_path ]; then
|
||||||
|
pip install sherpa-onnx
|
||||||
|
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||||
|
mkdir models
|
||||||
|
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
|
||||||
|
fi
|
||||||
|
|
||||||
|
python3 scripts/offline-decode-files.py \
|
||||||
|
--tokens=$model_path/tokens.txt \
|
||||||
|
--paraformer=$model_path/model.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=24000 \
|
||||||
|
--log-dir $wav_dir \
|
||||||
|
--feature-dim=80 \
|
||||||
|
--split-name $split_name \
|
||||||
|
--name sherpa_onnx \
|
||||||
|
$wav_files
|
||||||
|
|
||||||
|
# python3 scripts/paraformer-pytriton-client.py \
|
||||||
|
# --log-dir $wav_dir \
|
||||||
|
# --split-name $split_name \
|
||||||
|
# $wav_files
|
||||||
754
examples/grpo/cosyvoice2/scripts/offline-decode-files.py
Normal file
754
examples/grpo/cosyvoice2/scripts/offline-decode-files.py
Normal file
@@ -0,0 +1,754 @@
|
|||||||
|
# Copyright (c) 2023 by manyeyes
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||||
|
file(s) with a non-streaming model.
|
||||||
|
|
||||||
|
(1) For paraformer
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--paraformer=/path/to/paraformer.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(2) For transducer models from icefall
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--encoder=/path/to/encoder.onnx \
|
||||||
|
--decoder=/path/to/decoder.onnx \
|
||||||
|
--joiner=/path/to/joiner.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(3) For CTC models from NeMo
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||||
|
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--whisper-task=transcribe \
|
||||||
|
--num-threads=1 \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(5) For CTC models from WeNet
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||||
|
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(6) For tdnn models of the yesno recipe from icefall
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feature-dim=23 \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||||
|
|
||||||
|
Please refer to
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
|
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||||
|
used in this file.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict, Iterable, TextIO, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
from datasets import load_dataset
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
import kaldialign
|
||||||
|
from zhon.hanzi import punctuation
|
||||||
|
import string
|
||||||
|
punctuation_all = punctuation + string.punctuation
|
||||||
|
Pathlike = Union[str, Path]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_punctuation(text: str) -> str:
|
||||||
|
for x in punctuation_all:
|
||||||
|
if x == '\'':
|
||||||
|
continue
|
||||||
|
text = text.replace(x, '')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def store_transcripts(
|
||||||
|
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""Save predicted results and reference transcripts to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
File to save the results to.
|
||||||
|
texts:
|
||||||
|
An iterable of tuples. The first element is the cur_id, the second is
|
||||||
|
the reference transcript and the third element is the predicted result.
|
||||||
|
If it is a multi-talker ASR system, the ref and hyp may also be lists of
|
||||||
|
strings.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
with open(filename, "w", encoding="utf8") as f:
|
||||||
|
for cut_id, ref, hyp in texts:
|
||||||
|
if char_level:
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
print(f"{cut_id}:\tref={ref}", file=f)
|
||||||
|
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_error_stats(
|
||||||
|
f: TextIO,
|
||||||
|
test_set_name: str,
|
||||||
|
results: List[Tuple[str, str]],
|
||||||
|
enable_log: bool = True,
|
||||||
|
compute_CER: bool = False,
|
||||||
|
sclite_mode: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""Write statistics based on predicted results and reference transcripts.
|
||||||
|
|
||||||
|
It will write the following to the given file:
|
||||||
|
|
||||||
|
- WER
|
||||||
|
- number of insertions, deletions, substitutions, corrects and total
|
||||||
|
reference words. For example::
|
||||||
|
|
||||||
|
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||||
|
reference words (2337 correct)
|
||||||
|
|
||||||
|
- The difference between the reference transcript and predicted result.
|
||||||
|
An instance is given below::
|
||||||
|
|
||||||
|
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||||
|
|
||||||
|
The above example shows that the reference word is `EDISON`,
|
||||||
|
but it is predicted to `ADDISON` (a substitution error).
|
||||||
|
|
||||||
|
Another example is::
|
||||||
|
|
||||||
|
FOR THE FIRST DAY (SIR->*) I THINK
|
||||||
|
|
||||||
|
The reference word `SIR` is missing in the predicted
|
||||||
|
results (a deletion error).
|
||||||
|
results:
|
||||||
|
An iterable of tuples. The first element is the cut_id, the second is
|
||||||
|
the reference transcript and the third element is the predicted result.
|
||||||
|
enable_log:
|
||||||
|
If True, also print detailed WER to the console.
|
||||||
|
Otherwise, it is written only to the given file.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||||
|
ins: Dict[str, int] = defaultdict(int)
|
||||||
|
dels: Dict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
# `words` stores counts per word, as follows:
|
||||||
|
# corr, ref_sub, hyp_sub, ins, dels
|
||||||
|
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||||
|
num_corr = 0
|
||||||
|
ERR = "*"
|
||||||
|
|
||||||
|
if compute_CER:
|
||||||
|
for i, res in enumerate(results):
|
||||||
|
cut_id, ref, hyp = res
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
results[i] = (cut_id, ref, hyp)
|
||||||
|
|
||||||
|
for _cut_id, ref, hyp in results:
|
||||||
|
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||||
|
for ref_word, hyp_word in ali:
|
||||||
|
if ref_word == ERR:
|
||||||
|
ins[hyp_word] += 1
|
||||||
|
words[hyp_word][3] += 1
|
||||||
|
elif hyp_word == ERR:
|
||||||
|
dels[ref_word] += 1
|
||||||
|
words[ref_word][4] += 1
|
||||||
|
elif hyp_word != ref_word:
|
||||||
|
subs[(ref_word, hyp_word)] += 1
|
||||||
|
words[ref_word][1] += 1
|
||||||
|
words[hyp_word][2] += 1
|
||||||
|
else:
|
||||||
|
words[ref_word][0] += 1
|
||||||
|
num_corr += 1
|
||||||
|
ref_len = sum([len(r) for _, r, _ in results])
|
||||||
|
sub_errs = sum(subs.values())
|
||||||
|
ins_errs = sum(ins.values())
|
||||||
|
del_errs = sum(dels.values())
|
||||||
|
tot_errs = sub_errs + ins_errs + del_errs
|
||||||
|
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||||
|
|
||||||
|
if enable_log:
|
||||||
|
logging.info(
|
||||||
|
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||||
|
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||||
|
f"{del_errs} del, {sub_errs} sub ]"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"%WER = {tot_err_rate}", file=f)
|
||||||
|
print(
|
||||||
|
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||||
|
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||||
|
f"words ({num_corr} correct)",
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||||
|
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||||
|
for cut_id, ref, hyp in results:
|
||||||
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
|
combine_successive_errors = True
|
||||||
|
if combine_successive_errors:
|
||||||
|
ali = [[[x], [y]] for x, y in ali]
|
||||||
|
for i in range(len(ali) - 1):
|
||||||
|
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||||
|
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||||
|
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||||
|
ali[i] = [[], []]
|
||||||
|
ali = [
|
||||||
|
[
|
||||||
|
list(filter(lambda a: a != ERR, x)),
|
||||||
|
list(filter(lambda a: a != ERR, y)),
|
||||||
|
]
|
||||||
|
for x, y in ali
|
||||||
|
]
|
||||||
|
ali = list(filter(lambda x: x != [[], []], ali))
|
||||||
|
ali = [
|
||||||
|
[
|
||||||
|
ERR if x == [] else " ".join(x),
|
||||||
|
ERR if y == [] else " ".join(y),
|
||||||
|
]
|
||||||
|
for x, y in ali
|
||||||
|
]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{cut_id}:\t"
|
||||||
|
+ " ".join(
|
||||||
|
(
|
||||||
|
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||||
|
for ref_word, hyp_word in ali
|
||||||
|
)
|
||||||
|
),
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||||
|
|
||||||
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||||
|
print(f"{count} {ref} -> {hyp}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("DELETIONS: count ref", file=f)
|
||||||
|
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||||
|
print(f"{count} {ref}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("INSERTIONS: count hyp", file=f)
|
||||||
|
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||||
|
print(f"{count} {hyp}", file=f)
|
||||||
|
|
||||||
|
print("", file=f)
|
||||||
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||||
|
for _, word, counts in sorted(
|
||||||
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||||
|
):
|
||||||
|
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||||
|
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||||
|
ref_count = corr + ref_sub + dels
|
||||||
|
hyp_count = corr + hyp_sub + ins
|
||||||
|
|
||||||
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="Path to tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The file containing hotwords, one words/phrases per line, like
|
||||||
|
HELLO WORLD
|
||||||
|
你好世界
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-score",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="""
|
||||||
|
The hotword score of each token for biasing word/phrase. Used only if
|
||||||
|
--hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modeling-unit",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||||
|
Used only when hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-vocab",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||||
|
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||||
|
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||||
|
and modeling-unit is bpe or cjkchar+bpe.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from Paraformer",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nemo-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from NeMo CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--wenet-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from WeNet CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tdnn-model",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of threads for neural network computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-language",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="""It specifies the spoken language in the input audio file.
|
||||||
|
Example values: en, fr, de, zh, jp.
|
||||||
|
Available languages for multilingual models can be found at
|
||||||
|
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
If not specified, we infer the language from the input audio file.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-task",
|
||||||
|
default="transcribe",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
help="""For multilingual models, if you specify translate, the output
|
||||||
|
will be in English.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-tail-paddings",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="""Number of tail padding frames.
|
||||||
|
We have removed the 30-second constraint from whisper, so you need to
|
||||||
|
choose the amount of tail padding frames by yourself.
|
||||||
|
Use -1 to use a default value for tail padding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="True to show debug messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="""Sample rate of the feature extractor. Must match the one
|
||||||
|
expected by the model. Note: The input sound files can have a
|
||||||
|
different sample rate from this argument.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feature-dim",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Feature dimension. Must match the one expected by the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||||
|
"format with a single channel, and each sample has 16-bit, "
|
||||||
|
"i.e., int16_t. "
|
||||||
|
"The sample rate of the file can be arbitrary and does not need to "
|
||||||
|
"be 16 kHz",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--name",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-dir",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--label",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="wav_base_name label",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset related arguments for loading labels when label file is not provided
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-name",
|
||||||
|
type=str,
|
||||||
|
default="yuekai/seed_tts_cosy2",
|
||||||
|
help="Huggingface dataset name for loading labels",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="Dataset split name for loading labels",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_file_exists(filename: str):
|
||||||
|
assert Path(filename).is_file(), (
|
||||||
|
f"{filename} does not exist!\n"
|
||||||
|
"Please refer to "
|
||||||
|
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wave_filename:
|
||||||
|
Path to a wave file. It should be single channel and can be of type
|
||||||
|
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- A 1-D array of dtype np.float32 containing the samples,
|
||||||
|
which are normalized to the range [-1, 1].
|
||||||
|
- Sample rate of the wave file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||||
|
assert (
|
||||||
|
samples.ndim == 1
|
||||||
|
), f"Expected single channel, but got {samples.ndim} channels."
|
||||||
|
|
||||||
|
samples_float32 = samples.astype(np.float32)
|
||||||
|
|
||||||
|
return samples_float32, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text_alimeeting(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Text normalization similar to M2MeT challenge baseline.
|
||||||
|
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
text = text.replace('\u00A0', '') # test_hard
|
||||||
|
text = text.replace(" ", "")
|
||||||
|
text = text.replace("<sil>", "")
|
||||||
|
text = text.replace("<%>", "")
|
||||||
|
text = text.replace("<->", "")
|
||||||
|
text = text.replace("<$>", "")
|
||||||
|
text = text.replace("<#>", "")
|
||||||
|
text = text.replace("<_>", "")
|
||||||
|
text = text.replace("<space>", "")
|
||||||
|
text = text.replace("`", "")
|
||||||
|
text = text.replace("&", "")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
if re.search("[a-zA-Z]", text):
|
||||||
|
text = text.upper()
|
||||||
|
text = text.replace("A", "A")
|
||||||
|
text = text.replace("a", "A")
|
||||||
|
text = text.replace("b", "B")
|
||||||
|
text = text.replace("c", "C")
|
||||||
|
text = text.replace("k", "K")
|
||||||
|
text = text.replace("t", "T")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
text = text.replace("丶", "")
|
||||||
|
text = text.replace("。", "")
|
||||||
|
text = text.replace("、", "")
|
||||||
|
text = text.replace("?", "")
|
||||||
|
text = remove_punctuation(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert_file_exists(args.tokens)
|
||||||
|
assert args.num_threads > 0, args.num_threads
|
||||||
|
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
|
paraformer=args.paraformer,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Started!")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
streams, results = [], []
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
for i, wave_filename in enumerate(args.sound_files):
|
||||||
|
assert_file_exists(wave_filename)
|
||||||
|
samples, sample_rate = read_wave(wave_filename)
|
||||||
|
duration = len(samples) / sample_rate
|
||||||
|
total_duration += duration
|
||||||
|
s = recognizer.create_stream()
|
||||||
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
|
streams.append(s)
|
||||||
|
if i % 10 == 0:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
streams = []
|
||||||
|
print(f"Processed {i} files")
|
||||||
|
# process the last batch
|
||||||
|
if streams:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
end_time = time.time()
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
results_dict = {}
|
||||||
|
for wave_filename, result in zip(args.sound_files, results):
|
||||||
|
print(f"{wave_filename}\n{result}")
|
||||||
|
print("-" * 10)
|
||||||
|
wave_basename = Path(wave_filename).stem
|
||||||
|
results_dict[wave_basename] = result
|
||||||
|
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
rtf = elapsed_seconds / total_duration
|
||||||
|
print(f"num_threads: {args.num_threads}")
|
||||||
|
print(f"decoding_method: {args.decoding_method}")
|
||||||
|
print(f"Wave duration: {total_duration:.3f} s")
|
||||||
|
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
|
print(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load labels either from file or from dataset
|
||||||
|
labels_dict = {}
|
||||||
|
|
||||||
|
if args.label:
|
||||||
|
# Load labels from file (original functionality)
|
||||||
|
print(f"Loading labels from file: {args.label}")
|
||||||
|
with open(args.label, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
# fields = line.strip().split(" ")
|
||||||
|
# fields = [item for item in fields if item]
|
||||||
|
# assert len(fields) == 4
|
||||||
|
# prompt_text, prompt_audio, text, audio_path = fields
|
||||||
|
|
||||||
|
fields = line.strip().split("|")
|
||||||
|
fields = [item for item in fields if item]
|
||||||
|
assert len(fields) == 4
|
||||||
|
audio_path, prompt_text, prompt_audio, text = fields
|
||||||
|
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||||
|
else:
|
||||||
|
# Load labels from dataset (new functionality)
|
||||||
|
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
|
||||||
|
if 'zero' in args.split_name:
|
||||||
|
dataset_name = "yuekai/CV3-Eval"
|
||||||
|
else:
|
||||||
|
dataset_name = "yuekai/seed_tts_cosy2"
|
||||||
|
dataset = load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in dataset:
|
||||||
|
audio_id = item["id"]
|
||||||
|
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(labels_dict)} labels from dataset")
|
||||||
|
|
||||||
|
# Perform evaluation if labels are available
|
||||||
|
if labels_dict:
|
||||||
|
|
||||||
|
final_results = []
|
||||||
|
for key, value in results_dict.items():
|
||||||
|
if key in labels_dict:
|
||||||
|
final_results.append((key, labels_dict[key], value))
|
||||||
|
else:
|
||||||
|
print(f"Warning: No label found for {key}, skipping...")
|
||||||
|
|
||||||
|
if final_results:
|
||||||
|
store_transcripts(
|
||||||
|
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||||
|
)
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||||
|
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||||
|
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||||
|
print(f.readline()) # WER
|
||||||
|
print(f.readline()) # Detailed errors
|
||||||
|
else:
|
||||||
|
print("No matching labels found for evaluation")
|
||||||
|
else:
|
||||||
|
print("No labels available for evaluation")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
346
examples/grpo/cosyvoice2/token2wav_asr_server.py
Normal file
346
examples/grpo/cosyvoice2/token2wav_asr_server.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Pytriton server for token2wav conversion and ASR"""
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
from omnisense.models import OmniSenseVoiceSmall
|
||||||
|
from pytriton.proxy.types import Request
|
||||||
|
from pytriton.triton import Triton, TritonConfig
|
||||||
|
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||||
|
from pytriton.decorators import batch
|
||||||
|
import argparse
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import Any, List
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from scipy.signal import resample
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from jiwer import wer
|
||||||
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||||
|
|
||||||
|
# Chinese text normalizer (cached globally)
|
||||||
|
zh_tn_model = ZhNormalizer(
|
||||||
|
cache_dir="./cache",
|
||||||
|
remove_erhua=False,
|
||||||
|
remove_interjections=False,
|
||||||
|
remove_puncts=True,
|
||||||
|
overwrite_cache=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
|
||||||
|
logger = logging.getLogger("token2wav_asr_server")
|
||||||
|
|
||||||
|
|
||||||
|
class _ASR_Server:
|
||||||
|
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||||
|
|
||||||
|
def __init__(self, device_id: int):
|
||||||
|
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||||
|
|
||||||
|
@batch
|
||||||
|
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
|
||||||
|
"""
|
||||||
|
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||||
|
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||||
|
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||||
|
"""
|
||||||
|
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
|
||||||
|
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
|
||||||
|
|
||||||
|
results = self._model.transcribe_single_batch(
|
||||||
|
wavs,
|
||||||
|
language="zh",
|
||||||
|
textnorm="woitn",
|
||||||
|
)
|
||||||
|
texts = [result.text for result in results]
|
||||||
|
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||||
|
return {"TRANSCRIPTS": transcripts}
|
||||||
|
|
||||||
|
|
||||||
|
def audio_decode_cosyvoice2(
|
||||||
|
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate audio from tokens with optional tone and prompt embedding.
|
||||||
|
"""
|
||||||
|
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||||
|
"empty", prompt_text, prompt_speech_16k, 24000
|
||||||
|
)
|
||||||
|
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||||
|
token=audio_tokens.to(codec_decoder.model.device),
|
||||||
|
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_token_len=torch.tensor(
|
||||||
|
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||||
|
).to(codec_decoder.model.device),
|
||||||
|
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||||
|
codec_decoder.model.device
|
||||||
|
),
|
||||||
|
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||||
|
finalize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||||
|
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_hat
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_prompt_from_dataset(dataset):
|
||||||
|
"""
|
||||||
|
Get random prompt text and speech from the pre-loaded dataset.
|
||||||
|
Returns (prompt_text, prompt_speech_16k)
|
||||||
|
"""
|
||||||
|
random_idx = random.randint(0, len(dataset) - 1)
|
||||||
|
sample = dataset[random_idx]
|
||||||
|
|
||||||
|
# Extract audio data
|
||||||
|
audio_data = sample["audio"]
|
||||||
|
audio_array = audio_data["array"]
|
||||||
|
sample_rate = audio_data["sampling_rate"]
|
||||||
|
|
||||||
|
# Convert audio to 16kHz if needed
|
||||||
|
if sample_rate != 16000:
|
||||||
|
num_samples = int(len(audio_array) * (16000 / sample_rate))
|
||||||
|
audio_array = resample(audio_array, num_samples)
|
||||||
|
|
||||||
|
# Convert to torch tensor
|
||||||
|
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
|
||||||
|
prompt_text = sample["text"]
|
||||||
|
# remove space in prompt_text
|
||||||
|
prompt_text = prompt_text.replace(" ", "")
|
||||||
|
return prompt_text, prompt_speech_16k
|
||||||
|
|
||||||
|
|
||||||
|
class _Token2Wav_ASR:
|
||||||
|
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||||
|
|
||||||
|
def __init__(self, device_id: int):
|
||||||
|
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||||
|
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
|
||||||
|
|
||||||
|
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
|
||||||
|
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
|
||||||
|
# current CUDA context to the desired card before the object is created.
|
||||||
|
# Afterwards, all parameters loaded with the generic "cuda" device will
|
||||||
|
# reside on this GPU. We keep the selected id in `self.device_id` and
|
||||||
|
# will set the context again for every forward call to avoid race
|
||||||
|
# conditions when several instances are used in the same process.
|
||||||
|
|
||||||
|
self.device_id = device_id
|
||||||
|
|
||||||
|
# Construct the TTS codec decoder under the correct CUDA device context
|
||||||
|
with torch.cuda.device(self.device_id):
|
||||||
|
self.codec_decoder = CosyVoice2(
|
||||||
|
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@batch
|
||||||
|
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
|
||||||
|
"""
|
||||||
|
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||||
|
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||||
|
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||||
|
"""
|
||||||
|
# Ensure the default CUDA device is set correctly for this invocation
|
||||||
|
torch.cuda.set_device(self.device_id)
|
||||||
|
|
||||||
|
if self.device_id == 0:
|
||||||
|
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
|
||||||
|
|
||||||
|
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
|
||||||
|
|
||||||
|
# Decode ground-truth text strings (BYTES → str)
|
||||||
|
if GT_TEXT.ndim == 2:
|
||||||
|
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||||
|
else:
|
||||||
|
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||||
|
|
||||||
|
wavs = []
|
||||||
|
for tokens in tokens_list:
|
||||||
|
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
|
||||||
|
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
|
||||||
|
audio_hat = audio_decode_cosyvoice2(
|
||||||
|
audio_tokens,
|
||||||
|
prompt_text,
|
||||||
|
prompt_speech_16k,
|
||||||
|
self.codec_decoder,
|
||||||
|
)
|
||||||
|
# resample to 16000 using soundfile
|
||||||
|
audio_hat = audio_hat.squeeze(0).float().cpu()
|
||||||
|
audio_hat = audio_hat.numpy()
|
||||||
|
num_samples = int(len(audio_hat) * (16000 / 24000))
|
||||||
|
audio_hat = resample(audio_hat, num_samples)
|
||||||
|
wavs.append(audio_hat)
|
||||||
|
|
||||||
|
results = self.asr_model.transcribe_single_batch(
|
||||||
|
wavs,
|
||||||
|
language="zh",
|
||||||
|
textnorm="woitn",
|
||||||
|
)
|
||||||
|
texts = [result.text for result in results]
|
||||||
|
|
||||||
|
# ---------------- Reward computation ----------------
|
||||||
|
rewards = []
|
||||||
|
for gt_text, hyp_text in zip(gt_texts, texts):
|
||||||
|
gt_norm = zh_tn_model.normalize(gt_text).lower()
|
||||||
|
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
|
||||||
|
|
||||||
|
gt_pinyin = lazy_pinyin(
|
||||||
|
gt_norm,
|
||||||
|
style=Style.TONE3,
|
||||||
|
tone_sandhi=True,
|
||||||
|
neutral_tone_with_five=True,
|
||||||
|
)
|
||||||
|
hyp_pinyin = lazy_pinyin(
|
||||||
|
hyp_norm,
|
||||||
|
style=Style.TONE3,
|
||||||
|
tone_sandhi=True,
|
||||||
|
neutral_tone_with_five=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
|
||||||
|
reward_val = 1.0 - np.tanh(3.0 * c)
|
||||||
|
reward_val = max(0.0, min(1.0, reward_val))
|
||||||
|
rewards.append(reward_val)
|
||||||
|
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
|
||||||
|
|
||||||
|
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||||
|
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
|
||||||
|
|
||||||
|
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_function_factory(device_ids: List[int], model_name: str):
|
||||||
|
"""Creates a list of inference functions, one for each requested device ID."""
|
||||||
|
infer_funcs = []
|
||||||
|
for device_id in device_ids:
|
||||||
|
if model_name == "sensevoice":
|
||||||
|
infer_funcs.append(_ASR_Server(device_id=device_id))
|
||||||
|
else:
|
||||||
|
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
|
||||||
|
return infer_funcs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Batch size of request.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--number-of-instances-per-device",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of model instances to load.",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--number-of-devices",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of devices to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="token2wav_asr",
|
||||||
|
choices=["token2wav_asr", "sensevoice"],
|
||||||
|
help="Model name.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
log_level = logging.DEBUG if args.verbose else logging.INFO
|
||||||
|
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
||||||
|
|
||||||
|
triton_config = TritonConfig(
|
||||||
|
http_port=8000,
|
||||||
|
grpc_port=8001,
|
||||||
|
metrics_port=8002,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_ids = list(range(args.number_of_devices))
|
||||||
|
device_ids = device_ids * args.number_of_instances_per_device
|
||||||
|
|
||||||
|
with Triton(config=triton_config) as triton:
|
||||||
|
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
|
||||||
|
if args.model_name == "sensevoice":
|
||||||
|
triton.bind(
|
||||||
|
model_name="sensevoice",
|
||||||
|
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||||
|
inputs=[
|
||||||
|
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
|
||||||
|
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
config=ModelConfig(
|
||||||
|
max_batch_size=args.max_batch_size,
|
||||||
|
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
triton.bind(
|
||||||
|
model_name="token2wav_asr",
|
||||||
|
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||||
|
inputs=[
|
||||||
|
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
|
||||||
|
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
|
||||||
|
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||||
|
],
|
||||||
|
config=ModelConfig(
|
||||||
|
max_batch_size=args.max_batch_size,
|
||||||
|
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
logger.info("Serving inference")
|
||||||
|
triton.serve()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
# set random seed, so that you may reproduce your result.
|
|
||||||
__set_seed1: !apply:random.seed [1986]
|
|
||||||
__set_seed2: !apply:numpy.random.seed [1986]
|
|
||||||
__set_seed3: !apply:torch.manual_seed [1986]
|
|
||||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
|
||||||
|
|
||||||
# fixed params
|
|
||||||
sample_rate: 22050
|
|
||||||
text_encoder_input_size: 512
|
|
||||||
llm_input_size: 1024
|
|
||||||
llm_output_size: 1024
|
|
||||||
spk_embed_dim: 192
|
|
||||||
|
|
||||||
# model params
|
|
||||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
|
||||||
# for system/third_party class/function, we do not require this.
|
|
||||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
|
||||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
|
||||||
llm_input_size: !ref <llm_input_size>
|
|
||||||
llm_output_size: !ref <llm_output_size>
|
|
||||||
text_token_size: 51866
|
|
||||||
speech_token_size: 4096
|
|
||||||
length_normalized_loss: True
|
|
||||||
lsm_weight: 0
|
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
|
||||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
|
||||||
input_size: !ref <text_encoder_input_size>
|
|
||||||
output_size: 1024
|
|
||||||
attention_heads: 8
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 3
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0
|
|
||||||
normalize_before: True
|
|
||||||
input_layer: 'linear'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
use_cnn_module: False
|
|
||||||
macaron_style: False
|
|
||||||
use_dynamic_chunk: False
|
|
||||||
use_dynamic_left_chunk: False
|
|
||||||
static_chunk_size: 1
|
|
||||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
|
||||||
input_size: !ref <llm_input_size>
|
|
||||||
output_size: !ref <llm_output_size>
|
|
||||||
attention_heads: 8
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 7
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0
|
|
||||||
input_layer: 'linear_legacy'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
static_chunk_size: 1
|
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|
||||||
input_size: 512
|
|
||||||
output_size: 80
|
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
|
||||||
output_type: 'mel'
|
|
||||||
vocab_size: 4096
|
|
||||||
input_frame_rate: 50
|
|
||||||
only_mask_loss: True
|
|
||||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
|
||||||
output_size: 512
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 1024
|
|
||||||
num_blocks: 3
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.1
|
|
||||||
normalize_before: True
|
|
||||||
input_layer: 'linear'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
input_size: 512
|
|
||||||
use_cnn_module: False
|
|
||||||
macaron_style: False
|
|
||||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
|
||||||
channels: 80
|
|
||||||
sampling_ratios: [1, 1, 1, 1]
|
|
||||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
|
||||||
in_channels: 240
|
|
||||||
n_spks: 1
|
|
||||||
spk_emb_dim: 80
|
|
||||||
cfm_params: !new:omegaconf.DictConfig
|
|
||||||
content:
|
|
||||||
sigma_min: 1e-06
|
|
||||||
solver: 'euler'
|
|
||||||
t_scheduler: 'cosine'
|
|
||||||
training_cfg_rate: 0.2
|
|
||||||
inference_cfg_rate: 0.7
|
|
||||||
reg_loss_type: 'l1'
|
|
||||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
|
||||||
in_channels: 320
|
|
||||||
out_channels: 80
|
|
||||||
channels: [256, 256]
|
|
||||||
dropout: 0
|
|
||||||
attention_head_dim: 64
|
|
||||||
n_blocks: 4
|
|
||||||
num_mid_blocks: 8
|
|
||||||
num_heads: 8
|
|
||||||
act_fn: 'gelu'
|
|
||||||
|
|
||||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
|
||||||
in_channels: 80
|
|
||||||
base_channels: 512
|
|
||||||
nb_harmonics: 8
|
|
||||||
sampling_rate: !ref <sample_rate>
|
|
||||||
nsf_alpha: 0.1
|
|
||||||
nsf_sigma: 0.003
|
|
||||||
nsf_voiced_threshold: 10
|
|
||||||
upsample_rates: [8, 8]
|
|
||||||
upsample_kernel_sizes: [16, 16]
|
|
||||||
istft_params:
|
|
||||||
n_fft: 16
|
|
||||||
hop_len: 4
|
|
||||||
resblock_kernel_sizes: [3, 7, 11]
|
|
||||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
|
||||||
source_resblock_kernel_sizes: [7, 11]
|
|
||||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
|
||||||
lrelu_slope: 0.1
|
|
||||||
audio_limit: 0.99
|
|
||||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
|
||||||
num_class: 1
|
|
||||||
in_channels: 80
|
|
||||||
cond_channels: 512
|
|
||||||
|
|
||||||
# processor functions
|
|
||||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
|
||||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
|
|
||||||
multilingual: True
|
|
||||||
num_languages: 100
|
|
||||||
language: 'en'
|
|
||||||
task: 'transcribe'
|
|
||||||
allowed_special: 'all'
|
|
||||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
|
||||||
get_tokenizer: !ref <get_tokenizer>
|
|
||||||
allowed_special: !ref <allowed_special>
|
|
||||||
filter: !name:cosyvoice.dataset.processor.filter
|
|
||||||
max_length: 40960
|
|
||||||
min_length: 0
|
|
||||||
token_max_length: 200
|
|
||||||
token_min_length: 1
|
|
||||||
resample: !name:cosyvoice.dataset.processor.resample
|
|
||||||
resample_rate: !ref <sample_rate>
|
|
||||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
|
||||||
n_fft: 1024
|
|
||||||
num_mels: 80
|
|
||||||
sampling_rate: !ref <sample_rate>
|
|
||||||
hop_size: 256
|
|
||||||
win_size: 1024
|
|
||||||
fmin: 0
|
|
||||||
fmax: 8000
|
|
||||||
center: False
|
|
||||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
|
||||||
feat_extractor: !ref <feat_extractor>
|
|
||||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
|
||||||
normalize: True
|
|
||||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
|
||||||
shuffle_size: 1000
|
|
||||||
sort: !name:cosyvoice.dataset.processor.sort
|
|
||||||
sort_size: 500 # sort_size should be less than shuffle_size
|
|
||||||
batch: !name:cosyvoice.dataset.processor.batch
|
|
||||||
batch_type: 'dynamic'
|
|
||||||
max_frames_in_batch: 12000
|
|
||||||
padding: !name:cosyvoice.dataset.processor.padding
|
|
||||||
use_spk_embedding: False # change to True during sft
|
|
||||||
|
|
||||||
# dataset processor pipeline
|
|
||||||
data_pipeline: [
|
|
||||||
!ref <parquet_opener>,
|
|
||||||
!ref <tokenize>,
|
|
||||||
!ref <filter>,
|
|
||||||
!ref <resample>,
|
|
||||||
!ref <compute_fbank>,
|
|
||||||
!ref <parse_embedding>,
|
|
||||||
!ref <shuffle>,
|
|
||||||
!ref <sort>,
|
|
||||||
!ref <batch>,
|
|
||||||
!ref <padding>,
|
|
||||||
]
|
|
||||||
|
|
||||||
# train conf
|
|
||||||
train_conf:
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.002 # change to 0.001 if you want to train flow from scratch
|
|
||||||
scheduler: warmuplr
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
max_epoch: 200
|
|
||||||
grad_clip: 5
|
|
||||||
accum_grad: 2
|
|
||||||
log_interval: 100
|
|
||||||
save_per_step: -1
|
|
||||||
@@ -18,7 +18,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
text_encoder_input_size: !ref <text_encoder_input_size>
|
||||||
llm_input_size: !ref <llm_input_size>
|
llm_input_size: !ref <llm_input_size>
|
||||||
llm_output_size: !ref <llm_output_size>
|
llm_output_size: !ref <llm_output_size>
|
||||||
text_token_size: 51866
|
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
||||||
speech_token_size: 4096
|
speech_token_size: 4096
|
||||||
length_normalized_loss: True
|
length_normalized_loss: True
|
||||||
lsm_weight: 0
|
lsm_weight: 0
|
||||||
@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 6
|
num_blocks: 6
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
normalize_before: True
|
normalize_before: True
|
||||||
input_layer: 'linear'
|
input_layer: 'linear'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
|
|||||||
num_blocks: 14
|
num_blocks: 14
|
||||||
dropout_rate: 0.1
|
dropout_rate: 0.1
|
||||||
positional_dropout_rate: 0.1
|
positional_dropout_rate: 0.1
|
||||||
attention_dropout_rate: 0
|
attention_dropout_rate: 0.0
|
||||||
input_layer: 'linear_legacy'
|
input_layer: 'linear_legacy'
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
pos_enc_layer_type: 'rel_pos_espnet'
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
selfattention_layer_type: 'rel_selfattn'
|
||||||
static_chunk_size: 1
|
static_chunk_size: 1
|
||||||
|
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 25
|
||||||
|
win_size: 10
|
||||||
|
tau_r: 0.1
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||||
input_size: 512
|
input_size: 512
|
||||||
@@ -61,7 +66,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
spk_embed_dim: !ref <spk_embed_dim>
|
spk_embed_dim: !ref <spk_embed_dim>
|
||||||
output_type: 'mel'
|
output_type: 'mel'
|
||||||
vocab_size: 4096
|
vocab_size: 4096
|
||||||
input_frame_rate: 50
|
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
||||||
only_mask_loss: True
|
only_mask_loss: True
|
||||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||||
output_size: 512
|
output_size: 512
|
||||||
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
in_channels: 320
|
in_channels: 320
|
||||||
out_channels: 80
|
out_channels: 80
|
||||||
channels: [256, 256]
|
channels: [256, 256]
|
||||||
dropout: 0
|
dropout: 0.0
|
||||||
attention_head_dim: 64
|
attention_head_dim: 64
|
||||||
n_blocks: 4
|
n_blocks: 4
|
||||||
num_mid_blocks: 12
|
num_mid_blocks: 12
|
||||||
@@ -128,9 +133,28 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
|||||||
in_channels: 80
|
in_channels: 80
|
||||||
cond_channels: 512
|
cond_channels: 512
|
||||||
|
|
||||||
|
# gan related module
|
||||||
|
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||||
|
n_fft: 1024
|
||||||
|
num_mels: 80
|
||||||
|
sampling_rate: !ref <sample_rate>
|
||||||
|
hop_size: 256
|
||||||
|
win_size: 1024
|
||||||
|
fmin: 0
|
||||||
|
fmax: null
|
||||||
|
center: False
|
||||||
|
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||||
|
generator: !ref <hift>
|
||||||
|
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||||
|
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||||
|
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||||
|
mel_spec_transform: [
|
||||||
|
!ref <mel_spec_transform1>
|
||||||
|
]
|
||||||
|
|
||||||
# processor functions
|
# processor functions
|
||||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
|
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||||
multilingual: True
|
multilingual: True
|
||||||
num_languages: 100
|
num_languages: 100
|
||||||
language: 'en'
|
language: 'en'
|
||||||
@@ -146,6 +170,8 @@ filter: !name:cosyvoice.dataset.processor.filter
|
|||||||
token_min_length: 1
|
token_min_length: 1
|
||||||
resample: !name:cosyvoice.dataset.processor.resample
|
resample: !name:cosyvoice.dataset.processor.resample
|
||||||
resample_rate: !ref <sample_rate>
|
resample_rate: !ref <sample_rate>
|
||||||
|
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||||
|
truncate_length: 24576 # must be a multiplier of hop_size
|
||||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||||
n_fft: 1024
|
n_fft: 1024
|
||||||
num_mels: 80
|
num_mels: 80
|
||||||
@@ -157,6 +183,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
|||||||
center: False
|
center: False
|
||||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||||
feat_extractor: !ref <feat_extractor>
|
feat_extractor: !ref <feat_extractor>
|
||||||
|
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||||
|
sample_rate: !ref <sample_rate>
|
||||||
|
hop_size: 256
|
||||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||||
normalize: True
|
normalize: True
|
||||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||||
@@ -165,7 +194,7 @@ sort: !name:cosyvoice.dataset.processor.sort
|
|||||||
sort_size: 500 # sort_size should be less than shuffle_size
|
sort_size: 500 # sort_size should be less than shuffle_size
|
||||||
batch: !name:cosyvoice.dataset.processor.batch
|
batch: !name:cosyvoice.dataset.processor.batch
|
||||||
batch_type: 'dynamic'
|
batch_type: 'dynamic'
|
||||||
max_frames_in_batch: 2000
|
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
|
||||||
padding: !name:cosyvoice.dataset.processor.padding
|
padding: !name:cosyvoice.dataset.processor.padding
|
||||||
use_spk_embedding: False # change to True during sft
|
use_spk_embedding: False # change to True during sft
|
||||||
|
|
||||||
@@ -182,8 +211,22 @@ data_pipeline: [
|
|||||||
!ref <batch>,
|
!ref <batch>,
|
||||||
!ref <padding>,
|
!ref <padding>,
|
||||||
]
|
]
|
||||||
|
data_pipeline_gan: [
|
||||||
|
!ref <parquet_opener>,
|
||||||
|
!ref <tokenize>,
|
||||||
|
!ref <filter>,
|
||||||
|
!ref <resample>,
|
||||||
|
!ref <truncate>,
|
||||||
|
!ref <compute_fbank>,
|
||||||
|
!ref <compute_f0>,
|
||||||
|
!ref <parse_embedding>,
|
||||||
|
!ref <shuffle>,
|
||||||
|
!ref <sort>,
|
||||||
|
!ref <batch>,
|
||||||
|
!ref <padding>,
|
||||||
|
]
|
||||||
|
|
||||||
# train conf
|
# llm flow train conf
|
||||||
train_conf:
|
train_conf:
|
||||||
optim: adam
|
optim: adam
|
||||||
optim_conf:
|
optim_conf:
|
||||||
@@ -195,4 +238,20 @@ train_conf:
|
|||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
accum_grad: 2
|
accum_grad: 2
|
||||||
log_interval: 100
|
log_interval: 100
|
||||||
|
save_per_step: -1
|
||||||
|
|
||||||
|
# gan train conf
|
||||||
|
train_conf_gan:
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler: constantlr
|
||||||
|
optim_d: adam
|
||||||
|
optim_conf_d:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler_d: constantlr
|
||||||
|
max_epoch: 200
|
||||||
|
grad_clip: 5
|
||||||
|
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||||
|
log_interval: 100
|
||||||
save_per_step: -1
|
save_per_step: -1
|
||||||
@@ -1 +0,0 @@
|
|||||||
../../../cosyvoice
|
|
||||||
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
|
wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
|
||||||
|
|
||||||
@@ -39,13 +40,21 @@ def main():
|
|||||||
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
||||||
for k, v in spk2utt.items():
|
for k, v in spk2utt.items():
|
||||||
f.write('{} {}\n'.format(k, ' '.join(v)))
|
f.write('{} {}\n'.format(k, ' '.join(v)))
|
||||||
|
if args.instruct != '':
|
||||||
|
with open('{}/instruct'.format(args.des_dir), 'w') as f:
|
||||||
|
for k, v in utt2text.items():
|
||||||
|
f.write('{} {}\n'.format(k, args.instruct))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--src_dir',
|
parser.add_argument('--src_dir',
|
||||||
type=str)
|
type=str)
|
||||||
parser.add_argument('--des_dir',
|
parser.add_argument('--des_dir',
|
||||||
type=str)
|
type=str)
|
||||||
|
parser.add_argument('--instruct',
|
||||||
|
type=str,
|
||||||
|
default='')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main()
|
main()
|
||||||
|
|||||||
50
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
50
examples/libritts/cosyvoice/local/prepare_reject_sample.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
from cosyvoice.utils.file_utils import load_wav
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cosyvoice = CosyVoice2(args.ref_model)
|
||||||
|
|
||||||
|
utt2wav, utt2text = {}, {}
|
||||||
|
with open('{}/wav.scp'.format(args.src_dir)) as f:
|
||||||
|
for l in f:
|
||||||
|
l = l.split('\n')[0].split()
|
||||||
|
utt2wav[l[0]] = l[1]
|
||||||
|
with open('{}/text'.format(args.src_dir)) as f:
|
||||||
|
for l in f:
|
||||||
|
l = l.split('\n')[0].split()
|
||||||
|
utt2text[l[0]] = ' '.join(l[1:])
|
||||||
|
|
||||||
|
os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
|
||||||
|
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
||||||
|
for utt, wav in tqdm(utt2wav.items()):
|
||||||
|
prompt_speech_16k = load_wav(wav, 16000)
|
||||||
|
if prompt_speech_16k.shape[1] >= 30 * 16000:
|
||||||
|
continue
|
||||||
|
speech_list = []
|
||||||
|
for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
|
||||||
|
speech_list.append(j['tts_speech'])
|
||||||
|
negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
|
||||||
|
torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
|
||||||
|
f.write('{} {}\n'.format(utt, negative_wav))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--src_dir',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--des_dir',
|
||||||
|
type=str)
|
||||||
|
parser.add_argument('--ref_model',
|
||||||
|
type=str)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main()
|
||||||
@@ -27,7 +27,7 @@ fi
|
|||||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
tools/extract_embedding.py --dir data/$x \
|
../../../tools/extract_embedding.py --dir data/$x \
|
||||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
@@ -35,7 +35,7 @@ fi
|
|||||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
tools/extract_speech_token.py --dir data/$x \
|
../../../tools/extract_speech_token.py --dir data/$x \
|
||||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
@@ -44,30 +44,13 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||||
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
mkdir -p data/$x/parquet
|
mkdir -p data/$x/parquet
|
||||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||||
--num_processes 10 \
|
--num_processes 10 \
|
||||||
--src_dir data/$x \
|
--src_dir data/$x \
|
||||||
--des_dir data/$x/parquet
|
--des_dir data/$x/parquet
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# inference
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
|
||||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
|
||||||
for mode in sft zero_shot; do
|
|
||||||
python cosyvoice/bin/inference.py --mode $mode \
|
|
||||||
--gpu 0 \
|
|
||||||
--config conf/cosyvoice.yaml \
|
|
||||||
--prompt_data data/test-clean/parquet/data.list \
|
|
||||||
--prompt_utt2data data/test-clean/parquet/utt2data.list \
|
|
||||||
--tts_text `pwd`/tts_text.json \
|
|
||||||
--llm_model $pretrained_model_dir/llm.pt \
|
|
||||||
--flow_model $pretrained_model_dir/flow.pt \
|
|
||||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
|
||||||
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
|
|
||||||
# train llm
|
# train llm
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
@@ -77,16 +60,16 @@ num_workers=2
|
|||||||
prefetch=100
|
prefetch=100
|
||||||
train_engine=torch_ddp
|
train_engine=torch_ddp
|
||||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
|
echo "Run train. We only support llm traning for now"
|
||||||
if [ $train_engine == 'deepspeed' ]; then
|
if [ $train_engine == 'deepspeed' ]; then
|
||||||
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||||
fi
|
fi
|
||||||
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||||
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||||
for model in llm; do
|
for model in llm flow hifigan; do
|
||||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||||
cosyvoice/bin/train.py \
|
../../../cosyvoice/bin/train.py \
|
||||||
--train_engine $train_engine \
|
--train_engine $train_engine \
|
||||||
--config conf/cosyvoice.yaml \
|
--config conf/cosyvoice.yaml \
|
||||||
--train_data data/train.data.list \
|
--train_data data/train.data.list \
|
||||||
@@ -99,7 +82,28 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
--num_workers ${num_workers} \
|
--num_workers ${num_workers} \
|
||||||
--prefetch ${prefetch} \
|
--prefetch ${prefetch} \
|
||||||
--pin_memory \
|
--pin_memory \
|
||||||
|
--use_amp \
|
||||||
--deepspeed_config ./conf/ds_stage2.json \
|
--deepspeed_config ./conf/ds_stage2.json \
|
||||||
--deepspeed.save_states model+optimizer
|
--deepspeed.save_states model+optimizer
|
||||||
done
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# average model
|
||||||
|
average_num=5
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||||
|
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||||
|
python cosyvoice/bin/average_model.py \
|
||||||
|
--dst_model $decode_checkpoint \
|
||||||
|
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||||
|
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||||
|
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||||
|
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||||
fi
|
fi
|
||||||
@@ -1 +0,0 @@
|
|||||||
../../../tools
|
|
||||||
@@ -5,65 +5,47 @@ __set_seed3: !apply:torch.manual_seed [1986]
|
|||||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||||
|
|
||||||
# fixed params
|
# fixed params
|
||||||
sample_rate: 22050
|
sample_rate: 24000
|
||||||
text_encoder_input_size: 512
|
llm_input_size: 896
|
||||||
llm_input_size: 1024
|
llm_output_size: 896
|
||||||
llm_output_size: 1024
|
|
||||||
spk_embed_dim: 192
|
spk_embed_dim: 192
|
||||||
|
qwen_pretrain_path: ''
|
||||||
|
token_frame_rate: 25
|
||||||
|
token_mel_ratio: 2
|
||||||
|
|
||||||
|
# stream related params
|
||||||
|
chunk_size: 25 # streaming inference chunk size, in token
|
||||||
|
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
|
||||||
|
|
||||||
# model params
|
# model params
|
||||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||||
# for system/third_party class/function, we do not require this.
|
# for system/third_party class/function, we do not require this.
|
||||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
llm: !new:cosyvoice.llm.llm.Qwen2LM
|
||||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
|
||||||
llm_input_size: !ref <llm_input_size>
|
llm_input_size: !ref <llm_input_size>
|
||||||
llm_output_size: !ref <llm_output_size>
|
llm_output_size: !ref <llm_output_size>
|
||||||
text_token_size: 51866
|
speech_token_size: 6561
|
||||||
speech_token_size: 4096
|
|
||||||
length_normalized_loss: True
|
length_normalized_loss: True
|
||||||
lsm_weight: 0
|
lsm_weight: 0
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
mix_ratio: [5, 15]
|
||||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
|
||||||
input_size: !ref <text_encoder_input_size>
|
pretrain_path: !ref <qwen_pretrain_path>
|
||||||
output_size: 1024
|
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||||
attention_heads: 16
|
top_p: 0.8
|
||||||
linear_units: 4096
|
top_k: 25
|
||||||
num_blocks: 6
|
win_size: 10
|
||||||
dropout_rate: 0.1
|
tau_r: 0.1
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
normalize_before: True
|
|
||||||
input_layer: 'linear'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
use_cnn_module: False
|
|
||||||
macaron_style: False
|
|
||||||
use_dynamic_chunk: False
|
|
||||||
use_dynamic_left_chunk: False
|
|
||||||
static_chunk_size: 1
|
|
||||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
|
||||||
input_size: !ref <llm_input_size>
|
|
||||||
output_size: !ref <llm_output_size>
|
|
||||||
attention_heads: 16
|
|
||||||
linear_units: 4096
|
|
||||||
num_blocks: 14
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: 'linear_legacy'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
static_chunk_size: 1
|
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
|
||||||
input_size: 512
|
input_size: 512
|
||||||
output_size: 80
|
output_size: 80
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
spk_embed_dim: !ref <spk_embed_dim>
|
||||||
output_type: 'mel'
|
output_type: 'mel'
|
||||||
vocab_size: 4096
|
vocab_size: 6561
|
||||||
input_frame_rate: 50
|
input_frame_rate: !ref <token_frame_rate>
|
||||||
only_mask_loss: True
|
only_mask_loss: True
|
||||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
token_mel_ratio: !ref <token_mel_ratio>
|
||||||
|
pre_lookahead_len: 3
|
||||||
|
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
|
||||||
output_size: 512
|
output_size: 512
|
||||||
attention_heads: 8
|
attention_heads: 8
|
||||||
linear_units: 2048
|
linear_units: 2048
|
||||||
@@ -78,10 +60,8 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
input_size: 512
|
input_size: 512
|
||||||
use_cnn_module: False
|
use_cnn_module: False
|
||||||
macaron_style: False
|
macaron_style: False
|
||||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
static_chunk_size: !ref <chunk_size>
|
||||||
channels: 80
|
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
|
||||||
sampling_ratios: [1, 1, 1, 1]
|
|
||||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
|
||||||
in_channels: 240
|
in_channels: 240
|
||||||
n_spks: 1
|
n_spks: 1
|
||||||
spk_emb_dim: 80
|
spk_emb_dim: 80
|
||||||
@@ -93,16 +73,18 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|||||||
training_cfg_rate: 0.2
|
training_cfg_rate: 0.2
|
||||||
inference_cfg_rate: 0.7
|
inference_cfg_rate: 0.7
|
||||||
reg_loss_type: 'l1'
|
reg_loss_type: 'l1'
|
||||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
|
||||||
in_channels: 320
|
in_channels: 320
|
||||||
out_channels: 80
|
out_channels: 80
|
||||||
channels: [256, 256]
|
channels: [256]
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
attention_head_dim: 64
|
attention_head_dim: 64
|
||||||
n_blocks: 4
|
n_blocks: 4
|
||||||
num_mid_blocks: 12
|
num_mid_blocks: 12
|
||||||
num_heads: 8
|
num_heads: 8
|
||||||
act_fn: 'gelu'
|
act_fn: 'gelu'
|
||||||
|
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
|
||||||
|
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
|
||||||
|
|
||||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||||
in_channels: 80
|
in_channels: 80
|
||||||
@@ -112,15 +94,15 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
|||||||
nsf_alpha: 0.1
|
nsf_alpha: 0.1
|
||||||
nsf_sigma: 0.003
|
nsf_sigma: 0.003
|
||||||
nsf_voiced_threshold: 10
|
nsf_voiced_threshold: 10
|
||||||
upsample_rates: [8, 8]
|
upsample_rates: [8, 5, 3]
|
||||||
upsample_kernel_sizes: [16, 16]
|
upsample_kernel_sizes: [16, 11, 7]
|
||||||
istft_params:
|
istft_params:
|
||||||
n_fft: 16
|
n_fft: 16
|
||||||
hop_len: 4
|
hop_len: 4
|
||||||
resblock_kernel_sizes: [3, 7, 11]
|
resblock_kernel_sizes: [3, 7, 11]
|
||||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
source_resblock_kernel_sizes: [7, 11]
|
source_resblock_kernel_sizes: [7, 7, 11]
|
||||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
lrelu_slope: 0.1
|
lrelu_slope: 0.1
|
||||||
audio_limit: 0.99
|
audio_limit: 0.99
|
||||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||||
@@ -128,35 +110,60 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
|||||||
in_channels: 80
|
in_channels: 80
|
||||||
cond_channels: 512
|
cond_channels: 512
|
||||||
|
|
||||||
|
# gan related module
|
||||||
|
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||||
|
n_fft: 1920
|
||||||
|
num_mels: 80
|
||||||
|
sampling_rate: !ref <sample_rate>
|
||||||
|
hop_size: 480
|
||||||
|
win_size: 1920
|
||||||
|
fmin: 0
|
||||||
|
fmax: null
|
||||||
|
center: False
|
||||||
|
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||||
|
generator: !ref <hift>
|
||||||
|
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||||
|
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||||
|
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||||
|
mel_spec_transform: [
|
||||||
|
!ref <mel_spec_transform1>
|
||||||
|
]
|
||||||
|
|
||||||
# processor functions
|
# processor functions
|
||||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
|
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
|
||||||
multilingual: True
|
token_path: !ref <qwen_pretrain_path>
|
||||||
num_languages: 100
|
skip_special_tokens: True
|
||||||
language: 'en'
|
|
||||||
task: 'transcribe'
|
|
||||||
allowed_special: 'all'
|
allowed_special: 'all'
|
||||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||||
get_tokenizer: !ref <get_tokenizer>
|
get_tokenizer: !ref <get_tokenizer>
|
||||||
allowed_special: !ref <allowed_special>
|
allowed_special: !ref <allowed_special>
|
||||||
filter: !name:cosyvoice.dataset.processor.filter
|
filter: !name:cosyvoice.dataset.processor.filter
|
||||||
max_length: 40960
|
max_length: 6000
|
||||||
min_length: 0
|
min_length: 100
|
||||||
token_max_length: 200
|
token_max_length: 200
|
||||||
token_min_length: 1
|
token_min_length: 1
|
||||||
resample: !name:cosyvoice.dataset.processor.resample
|
resample: !name:cosyvoice.dataset.processor.resample
|
||||||
resample_rate: !ref <sample_rate>
|
resample_rate: !ref <sample_rate>
|
||||||
|
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||||
|
truncate_length: 24480 # must be a multiplier of hop_size
|
||||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||||
n_fft: 1024
|
n_fft: 1920
|
||||||
num_mels: 80
|
num_mels: 80
|
||||||
sampling_rate: !ref <sample_rate>
|
sampling_rate: !ref <sample_rate>
|
||||||
hop_size: 256
|
hop_size: 480
|
||||||
win_size: 1024
|
win_size: 1920
|
||||||
fmin: 0
|
fmin: 0
|
||||||
fmax: 8000
|
fmax: 8000
|
||||||
center: False
|
center: False
|
||||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||||
feat_extractor: !ref <feat_extractor>
|
feat_extractor: !ref <feat_extractor>
|
||||||
|
num_frames: 960
|
||||||
|
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
||||||
|
num_frames: 960
|
||||||
|
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||||
|
sample_rate: !ref <sample_rate>
|
||||||
|
hop_size: 480
|
||||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||||
normalize: True
|
normalize: True
|
||||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||||
@@ -169,6 +176,7 @@ batch: !name:cosyvoice.dataset.processor.batch
|
|||||||
padding: !name:cosyvoice.dataset.processor.padding
|
padding: !name:cosyvoice.dataset.processor.padding
|
||||||
use_spk_embedding: False # change to True during sft
|
use_spk_embedding: False # change to True during sft
|
||||||
|
|
||||||
|
|
||||||
# dataset processor pipeline
|
# dataset processor pipeline
|
||||||
data_pipeline: [
|
data_pipeline: [
|
||||||
!ref <parquet_opener>,
|
!ref <parquet_opener>,
|
||||||
@@ -177,22 +185,53 @@ data_pipeline: [
|
|||||||
!ref <resample>,
|
!ref <resample>,
|
||||||
!ref <compute_fbank>,
|
!ref <compute_fbank>,
|
||||||
!ref <parse_embedding>,
|
!ref <parse_embedding>,
|
||||||
|
!ref <compute_whisper_fbank>,
|
||||||
|
!ref <shuffle>,
|
||||||
|
!ref <sort>,
|
||||||
|
!ref <batch>,
|
||||||
|
!ref <padding>,
|
||||||
|
]
|
||||||
|
data_pipeline_gan: [
|
||||||
|
!ref <parquet_opener>,
|
||||||
|
!ref <tokenize>,
|
||||||
|
!ref <filter>,
|
||||||
|
!ref <resample>,
|
||||||
|
!ref <truncate>,
|
||||||
|
!ref <compute_fbank>,
|
||||||
|
!ref <compute_f0>,
|
||||||
|
!ref <parse_embedding>,
|
||||||
!ref <shuffle>,
|
!ref <shuffle>,
|
||||||
!ref <sort>,
|
!ref <sort>,
|
||||||
!ref <batch>,
|
!ref <batch>,
|
||||||
!ref <padding>,
|
!ref <padding>,
|
||||||
]
|
]
|
||||||
|
|
||||||
# train conf
|
# llm flow train conf
|
||||||
train_conf:
|
train_conf:
|
||||||
optim: adam
|
optim: adam
|
||||||
optim_conf:
|
optim_conf:
|
||||||
lr: 0.001 # change to 1e-5 during sft
|
lr: 1e-5 # change to 1e-5 during sft
|
||||||
scheduler: warmuplr # change to constantlr during sft
|
scheduler: constantlr # change to constantlr during sft
|
||||||
scheduler_conf:
|
scheduler_conf:
|
||||||
warmup_steps: 2500
|
warmup_steps: 2500
|
||||||
max_epoch: 200
|
max_epoch: 200
|
||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
accum_grad: 2
|
accum_grad: 2
|
||||||
log_interval: 100
|
log_interval: 100
|
||||||
|
save_per_step: -1
|
||||||
|
|
||||||
|
# gan train conf
|
||||||
|
train_conf_gan:
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler: constantlr
|
||||||
|
optim_d: adam
|
||||||
|
optim_conf_d:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler_d: constantlr
|
||||||
|
max_epoch: 200
|
||||||
|
grad_clip: 5
|
||||||
|
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||||
|
log_interval: 100
|
||||||
save_per_step: -1
|
save_per_step: -1
|
||||||
1
examples/libritts/cosyvoice2/local
Symbolic link
1
examples/libritts/cosyvoice2/local
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../cosyvoice/local
|
||||||
1
examples/libritts/cosyvoice2/path.sh
Symbolic link
1
examples/libritts/cosyvoice2/path.sh
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../cosyvoice/path.sh
|
||||||
96
examples/libritts/cosyvoice2/run.sh
Normal file
96
examples/libritts/cosyvoice2/run.sh
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
||||||
|
. ./path.sh || exit 1;
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=3
|
||||||
|
|
||||||
|
data_url=www.openslr.org/resources/60
|
||||||
|
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
||||||
|
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
|
||||||
|
|
||||||
|
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||||
|
echo "Data Download"
|
||||||
|
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||||
|
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x
|
||||||
|
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x/parquet
|
||||||
|
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||||
|
--num_processes 10 \
|
||||||
|
--src_dir data/$x \
|
||||||
|
--des_dir data/$x/parquet
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# train llm
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
job_id=1986
|
||||||
|
dist_backend="nccl"
|
||||||
|
num_workers=2
|
||||||
|
prefetch=100
|
||||||
|
train_engine=torch_ddp
|
||||||
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
|
echo "Run train. We only support llm traning for now"
|
||||||
|
if [ $train_engine == 'deepspeed' ]; then
|
||||||
|
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||||
|
fi
|
||||||
|
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||||
|
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
|
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||||
|
../../../cosyvoice/bin/train.py \
|
||||||
|
--train_engine $train_engine \
|
||||||
|
--config conf/cosyvoice2.yaml \
|
||||||
|
--train_data data/train.data.list \
|
||||||
|
--cv_data data/dev.data.list \
|
||||||
|
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||||
|
--onnx_path $pretrained_model_dir \
|
||||||
|
--model $model \
|
||||||
|
--checkpoint $pretrained_model_dir/$model.pt \
|
||||||
|
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
|
||||||
|
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
|
||||||
|
--ddp.dist_backend $dist_backend \
|
||||||
|
--num_workers ${num_workers} \
|
||||||
|
--prefetch ${prefetch} \
|
||||||
|
--pin_memory \
|
||||||
|
--use_amp \
|
||||||
|
--deepspeed_config ./conf/ds_stage2.json \
|
||||||
|
--deepspeed.save_states model+optimizer
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# average model
|
||||||
|
average_num=5
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||||
|
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||||
|
python cosyvoice/bin/average_model.py \
|
||||||
|
--dst_model $decode_checkpoint \
|
||||||
|
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||||
|
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||||
|
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||||
|
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||||
|
fi
|
||||||
124
examples/libritts/cosyvoice2/run_dpo.sh
Normal file
124
examples/libritts/cosyvoice2/run_dpo.sh
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
||||||
|
. ./path.sh || exit 1;
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=3
|
||||||
|
|
||||||
|
data_url=www.openslr.org/resources/60
|
||||||
|
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
||||||
|
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
|
||||||
|
|
||||||
|
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||||
|
echo "Data Download"
|
||||||
|
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||||
|
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x
|
||||||
|
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
echo "Prepare negative samples using CosyVoice2-0.5B, this is also our reference model.
|
||||||
|
Here we use CosyVoice2-0.5B generated audio as reject sample for simplicity, you can use metric like wer/similarity."
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500; do
|
||||||
|
mkdir -p data/${x}_reject
|
||||||
|
python local/prepare_reject_sample.py --src_dir data/$x --des_dir data/${x}_reject --ref_model $pretrained_model_dir
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
../../../tools/extract_embedding.py --dir data/$x \
|
||||||
|
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
|
||||||
|
../../../tools/extract_speech_token.py --dir data/$x \
|
||||||
|
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x/parquet
|
||||||
|
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||||
|
--num_processes 10 \
|
||||||
|
--dpo \
|
||||||
|
--src_dir data/$x \
|
||||||
|
--des_dir data/$x/parquet
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# train llm
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
job_id=1986
|
||||||
|
dist_backend="nccl"
|
||||||
|
num_workers=2
|
||||||
|
prefetch=100
|
||||||
|
train_engine=torch_ddp
|
||||||
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
|
echo "Run train. We only support llm traning for now"
|
||||||
|
if [ $train_engine == 'deepspeed' ]; then
|
||||||
|
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||||
|
fi
|
||||||
|
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||||
|
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||||
|
# NOTE only llm supports dpo
|
||||||
|
for model in llm; do
|
||||||
|
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
|
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||||
|
../../../cosyvoice/bin/train.py \
|
||||||
|
--train_engine $train_engine \
|
||||||
|
--config conf/cosyvoice2.yaml \
|
||||||
|
--train_data data/train.data.list \
|
||||||
|
--cv_data data/dev.data.list \
|
||||||
|
--onnx_path $pretrained_model_dir \
|
||||||
|
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||||
|
--model $model \
|
||||||
|
--checkpoint $pretrained_model_dir/$model.pt \
|
||||||
|
--ref_model $pretrained_model_dir/llm.pt \
|
||||||
|
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
|
||||||
|
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
|
||||||
|
--ddp.dist_backend $dist_backend \
|
||||||
|
--num_workers ${num_workers} \
|
||||||
|
--prefetch ${prefetch} \
|
||||||
|
--pin_memory \
|
||||||
|
--use_amp \
|
||||||
|
--dpo \
|
||||||
|
--deepspeed_config ./conf/ds_stage2.json \
|
||||||
|
--deepspeed.save_states model+optimizer
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# average model
|
||||||
|
average_num=5
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||||
|
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||||
|
python cosyvoice/bin/average_model.py \
|
||||||
|
--dst_model $decode_checkpoint \
|
||||||
|
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||||
|
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||||
|
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||||
|
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||||
|
fi
|
||||||
1
examples/libritts/cosyvoice2/tts_text.json
Symbolic link
1
examples/libritts/cosyvoice2/tts_text.json
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../cosyvoice/tts_text.json
|
||||||
227
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml
Normal file
227
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
# set random seed, so that you may reproduce your result.
|
||||||
|
__set_seed1: !apply:random.seed [1986]
|
||||||
|
__set_seed2: !apply:numpy.random.seed [1986]
|
||||||
|
__set_seed3: !apply:torch.manual_seed [1986]
|
||||||
|
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||||
|
|
||||||
|
# fixed params
|
||||||
|
sample_rate: 24000
|
||||||
|
llm_input_size: 896
|
||||||
|
llm_output_size: 896
|
||||||
|
spk_embed_dim: 192
|
||||||
|
qwen_pretrain_path: ''
|
||||||
|
token_frame_rate: 25
|
||||||
|
token_mel_ratio: 2
|
||||||
|
|
||||||
|
# stream related params
|
||||||
|
chunk_size: 25 # streaming inference chunk size, in token
|
||||||
|
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
|
||||||
|
|
||||||
|
# model params
|
||||||
|
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||||
|
# for system/third_party class/function, we do not require this.
|
||||||
|
llm: !new:cosyvoice.llm.llm.CosyVoice3LM
|
||||||
|
llm_input_size: !ref <llm_input_size>
|
||||||
|
llm_output_size: !ref <llm_output_size>
|
||||||
|
speech_token_size: 6561
|
||||||
|
length_normalized_loss: True
|
||||||
|
lsm_weight: 0
|
||||||
|
mix_ratio: [5, 15]
|
||||||
|
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
|
||||||
|
pretrain_path: !ref <qwen_pretrain_path>
|
||||||
|
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||||
|
top_p: 0.8
|
||||||
|
top_k: 25
|
||||||
|
win_size: 10
|
||||||
|
tau_r: 0.1
|
||||||
|
|
||||||
|
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithDiT
|
||||||
|
input_size: 80
|
||||||
|
output_size: 80
|
||||||
|
spk_embed_dim: !ref <spk_embed_dim>
|
||||||
|
output_type: 'mel'
|
||||||
|
vocab_size: 6561
|
||||||
|
input_frame_rate: !ref <token_frame_rate>
|
||||||
|
only_mask_loss: True
|
||||||
|
token_mel_ratio: !ref <token_mel_ratio>
|
||||||
|
pre_lookahead_len: 3
|
||||||
|
pre_lookahead_layer: !new:cosyvoice.transformer.upsample_encoder.PreLookaheadLayer
|
||||||
|
in_channels: 80
|
||||||
|
channels: 1024
|
||||||
|
pre_lookahead_len: 3
|
||||||
|
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
|
||||||
|
in_channels: 240
|
||||||
|
n_spks: 1
|
||||||
|
spk_emb_dim: 80
|
||||||
|
cfm_params: !new:omegaconf.DictConfig
|
||||||
|
content:
|
||||||
|
sigma_min: 1e-06
|
||||||
|
solver: 'euler'
|
||||||
|
t_scheduler: 'cosine'
|
||||||
|
training_cfg_rate: 0.2
|
||||||
|
inference_cfg_rate: 0.7
|
||||||
|
reg_loss_type: 'l1'
|
||||||
|
estimator: !new:cosyvoice.flow.DiT.dit.DiT
|
||||||
|
dim: 1024
|
||||||
|
depth: 22
|
||||||
|
heads: 16
|
||||||
|
dim_head: 64
|
||||||
|
ff_mult: 2
|
||||||
|
mel_dim: 80
|
||||||
|
mu_dim: 80
|
||||||
|
spk_dim: 80
|
||||||
|
out_channels: 80
|
||||||
|
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
|
||||||
|
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
|
||||||
|
|
||||||
|
hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator
|
||||||
|
in_channels: 80
|
||||||
|
base_channels: 512
|
||||||
|
nb_harmonics: 8
|
||||||
|
sampling_rate: !ref <sample_rate>
|
||||||
|
nsf_alpha: 0.1
|
||||||
|
nsf_sigma: 0.003
|
||||||
|
nsf_voiced_threshold: 10
|
||||||
|
upsample_rates: [8, 5, 3]
|
||||||
|
upsample_kernel_sizes: [16, 11, 7]
|
||||||
|
istft_params:
|
||||||
|
n_fft: 16
|
||||||
|
hop_len: 4
|
||||||
|
resblock_kernel_sizes: [3, 7, 11]
|
||||||
|
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
|
source_resblock_kernel_sizes: [7, 7, 11]
|
||||||
|
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
|
lrelu_slope: 0.1
|
||||||
|
audio_limit: 0.99
|
||||||
|
conv_pre_look_right: 4
|
||||||
|
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor
|
||||||
|
num_class: 1
|
||||||
|
in_channels: 80
|
||||||
|
cond_channels: 512
|
||||||
|
|
||||||
|
# gan related module
|
||||||
|
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||||
|
n_fft: 1920
|
||||||
|
num_mels: 80
|
||||||
|
sampling_rate: !ref <sample_rate>
|
||||||
|
hop_size: 480
|
||||||
|
win_size: 1920
|
||||||
|
fmin: 0
|
||||||
|
fmax: null
|
||||||
|
center: False
|
||||||
|
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||||
|
generator: !ref <hift>
|
||||||
|
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||||
|
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||||
|
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||||
|
mel_spec_transform: [
|
||||||
|
!ref <mel_spec_transform1>
|
||||||
|
]
|
||||||
|
|
||||||
|
# processor functions
|
||||||
|
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||||
|
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
|
||||||
|
token_path: !ref <qwen_pretrain_path>
|
||||||
|
skip_special_tokens: True
|
||||||
|
version: cosyvoice3
|
||||||
|
allowed_special: 'all'
|
||||||
|
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||||
|
get_tokenizer: !ref <get_tokenizer>
|
||||||
|
allowed_special: !ref <allowed_special>
|
||||||
|
filter: !name:cosyvoice.dataset.processor.filter
|
||||||
|
max_length: 6000
|
||||||
|
min_length: 100
|
||||||
|
token_max_length: 200
|
||||||
|
token_min_length: 1
|
||||||
|
resample: !name:cosyvoice.dataset.processor.resample
|
||||||
|
resample_rate: !ref <sample_rate>
|
||||||
|
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||||
|
truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
|
||||||
|
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||||
|
n_fft: 1920
|
||||||
|
num_mels: 80
|
||||||
|
sampling_rate: !ref <sample_rate>
|
||||||
|
hop_size: 480
|
||||||
|
win_size: 1920
|
||||||
|
fmin: 0
|
||||||
|
fmax: null
|
||||||
|
center: False
|
||||||
|
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||||
|
feat_extractor: !ref <feat_extractor>
|
||||||
|
num_frames: 960
|
||||||
|
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
||||||
|
num_frames: 960
|
||||||
|
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||||
|
sample_rate: !ref <sample_rate>
|
||||||
|
hop_size: 480
|
||||||
|
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||||
|
normalize: True
|
||||||
|
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||||
|
shuffle_size: 1000
|
||||||
|
sort: !name:cosyvoice.dataset.processor.sort
|
||||||
|
sort_size: 500 # sort_size should be less than shuffle_size
|
||||||
|
batch: !name:cosyvoice.dataset.processor.batch
|
||||||
|
batch_type: 'dynamic'
|
||||||
|
max_frames_in_batch: 2000
|
||||||
|
padding: !name:cosyvoice.dataset.processor.padding
|
||||||
|
use_spk_embedding: False # change to True during sft
|
||||||
|
|
||||||
|
|
||||||
|
# dataset processor pipeline
|
||||||
|
data_pipeline: [
|
||||||
|
!ref <parquet_opener>,
|
||||||
|
!ref <tokenize>,
|
||||||
|
!ref <filter>,
|
||||||
|
!ref <resample>,
|
||||||
|
!ref <compute_fbank>,
|
||||||
|
!ref <parse_embedding>,
|
||||||
|
!ref <compute_whisper_fbank>,
|
||||||
|
!ref <shuffle>,
|
||||||
|
!ref <sort>,
|
||||||
|
!ref <batch>,
|
||||||
|
!ref <padding>,
|
||||||
|
]
|
||||||
|
data_pipeline_gan: [
|
||||||
|
!ref <parquet_opener>,
|
||||||
|
!ref <tokenize>,
|
||||||
|
!ref <filter>,
|
||||||
|
!ref <resample>,
|
||||||
|
!ref <truncate>,
|
||||||
|
!ref <compute_fbank>,
|
||||||
|
!ref <compute_f0>,
|
||||||
|
!ref <parse_embedding>,
|
||||||
|
!ref <shuffle>,
|
||||||
|
!ref <sort>,
|
||||||
|
!ref <batch>,
|
||||||
|
!ref <padding>,
|
||||||
|
]
|
||||||
|
|
||||||
|
# llm flow train conf
|
||||||
|
train_conf:
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 1e-5 # change to 1e-5 during sft
|
||||||
|
scheduler: constantlr # change to constantlr during sft
|
||||||
|
scheduler_conf:
|
||||||
|
warmup_steps: 2500
|
||||||
|
max_epoch: 200
|
||||||
|
grad_clip: 5
|
||||||
|
accum_grad: 2
|
||||||
|
log_interval: 100
|
||||||
|
save_per_step: -1
|
||||||
|
|
||||||
|
# gan train conf
|
||||||
|
train_conf_gan:
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler: constantlr
|
||||||
|
optim_d: adam
|
||||||
|
optim_conf_d:
|
||||||
|
lr: 0.0002 # use small lr for gan training
|
||||||
|
scheduler_d: constantlr
|
||||||
|
max_epoch: 200
|
||||||
|
grad_clip: 5
|
||||||
|
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||||
|
log_interval: 100
|
||||||
|
save_per_step: -1
|
||||||
42
examples/libritts/cosyvoice3/conf/ds_stage2.json
Normal file
42
examples/libritts/cosyvoice3/conf/ds_stage2.json
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
{
|
||||||
|
"train_micro_batch_size_per_gpu": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"steps_per_print": 100,
|
||||||
|
"gradient_clipping": 5,
|
||||||
|
"fp16": {
|
||||||
|
"enabled": false,
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 16,
|
||||||
|
"loss_scale_window": 256,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"consecutive_hysteresis": false,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": false
|
||||||
|
},
|
||||||
|
"zero_force_ds_cpu_optimizer": false,
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "none",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"allgather_partitions": true,
|
||||||
|
"allgather_bucket_size": 5e8,
|
||||||
|
"overlap_comm": false,
|
||||||
|
"reduce_scatter": true,
|
||||||
|
"reduce_bucket_size": 5e8,
|
||||||
|
"contiguous_gradients" : true
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": 0.001,
|
||||||
|
"weight_decay": 0.0001,
|
||||||
|
"torch_adam": true,
|
||||||
|
"adam_w_mode": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1
examples/libritts/cosyvoice3/local
Symbolic link
1
examples/libritts/cosyvoice3/local
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../cosyvoice/local
|
||||||
1
examples/libritts/cosyvoice3/path.sh
Symbolic link
1
examples/libritts/cosyvoice3/path.sh
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../cosyvoice/path.sh
|
||||||
97
examples/libritts/cosyvoice3/run.sh
Normal file
97
examples/libritts/cosyvoice3/run.sh
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2024 Alibaba Inc. All Rights Reserved.
|
||||||
|
. ./path.sh || exit 1;
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=3
|
||||||
|
|
||||||
|
data_url=www.openslr.org/resources/60
|
||||||
|
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
|
||||||
|
pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B
|
||||||
|
|
||||||
|
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||||
|
echo "Data Download"
|
||||||
|
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||||
|
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x
|
||||||
|
# NOTE in CosyVoice3, we add instruct in sequence
|
||||||
|
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct "You are a helpful assistant.<|endofprompt|>"
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# NOTE embedding/token extraction is not necessary now as we support online feature extraction
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||||
|
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
|
||||||
|
mkdir -p data/$x/parquet
|
||||||
|
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||||
|
--num_processes 10 \
|
||||||
|
--src_dir data/$x \
|
||||||
|
--des_dir data/$x/parquet
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# train llm
|
||||||
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
|
job_id=1986
|
||||||
|
dist_backend="nccl"
|
||||||
|
num_workers=2
|
||||||
|
prefetch=100
|
||||||
|
train_engine=torch_ddp
|
||||||
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
|
echo "Run train. We only support llm traning for now"
|
||||||
|
if [ $train_engine == 'deepspeed' ]; then
|
||||||
|
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
|
||||||
|
fi
|
||||||
|
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
|
||||||
|
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
|
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
|
||||||
|
../../../cosyvoice/bin/train.py \
|
||||||
|
--train_engine $train_engine \
|
||||||
|
--config conf/cosyvoice3.yaml \
|
||||||
|
--train_data data/train.data.list \
|
||||||
|
--cv_data data/dev.data.list \
|
||||||
|
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
|
||||||
|
--onnx_path $pretrained_model_dir \
|
||||||
|
--model $model \
|
||||||
|
--checkpoint $pretrained_model_dir/$model.pt \
|
||||||
|
--model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \
|
||||||
|
--tensorboard_dir `pwd`/tensorboard/cosyvoice3/$model/$train_engine \
|
||||||
|
--ddp.dist_backend $dist_backend \
|
||||||
|
--num_workers ${num_workers} \
|
||||||
|
--prefetch ${prefetch} \
|
||||||
|
--pin_memory \
|
||||||
|
--use_amp \
|
||||||
|
--deepspeed_config ./conf/ds_stage2.json \
|
||||||
|
--deepspeed.save_states model+optimizer
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# average model
|
||||||
|
average_num=5
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||||
|
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||||
|
python cosyvoice/bin/average_model.py \
|
||||||
|
--dst_model $decode_checkpoint \
|
||||||
|
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
|
||||||
|
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||||
|
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||||
|
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||||
|
fi
|
||||||
1
examples/magicdata-read/cosyvoice/conf
Symbolic link
1
examples/magicdata-read/cosyvoice/conf
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../libritts/cosyvoice/conf
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
# set random seed, so that you may reproduce your result.
|
|
||||||
__set_seed1: !apply:random.seed [1986]
|
|
||||||
__set_seed2: !apply:numpy.random.seed [1986]
|
|
||||||
__set_seed3: !apply:torch.manual_seed [1986]
|
|
||||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
|
||||||
|
|
||||||
# fixed params
|
|
||||||
sample_rate: 22050
|
|
||||||
text_encoder_input_size: 512
|
|
||||||
llm_input_size: 1024
|
|
||||||
llm_output_size: 1024
|
|
||||||
spk_embed_dim: 192
|
|
||||||
|
|
||||||
# model params
|
|
||||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
|
||||||
# for system/third_party class/function, we do not require this.
|
|
||||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
|
||||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
|
||||||
llm_input_size: !ref <llm_input_size>
|
|
||||||
llm_output_size: !ref <llm_output_size>
|
|
||||||
text_token_size: 51866
|
|
||||||
speech_token_size: 4096
|
|
||||||
length_normalized_loss: True
|
|
||||||
lsm_weight: 0
|
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
|
||||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
|
||||||
input_size: !ref <text_encoder_input_size>
|
|
||||||
output_size: 1024
|
|
||||||
attention_heads: 8
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 3
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
normalize_before: True
|
|
||||||
input_layer: 'linear'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
use_cnn_module: False
|
|
||||||
macaron_style: False
|
|
||||||
use_dynamic_chunk: False
|
|
||||||
use_dynamic_left_chunk: False
|
|
||||||
static_chunk_size: 1
|
|
||||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
|
||||||
input_size: !ref <llm_input_size>
|
|
||||||
output_size: !ref <llm_output_size>
|
|
||||||
attention_heads: 8
|
|
||||||
linear_units: 2048
|
|
||||||
num_blocks: 7
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.0
|
|
||||||
input_layer: 'linear_legacy'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
static_chunk_size: 1
|
|
||||||
|
|
||||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
|
||||||
input_size: 512
|
|
||||||
output_size: 80
|
|
||||||
spk_embed_dim: !ref <spk_embed_dim>
|
|
||||||
output_type: 'mel'
|
|
||||||
vocab_size: 4096
|
|
||||||
input_frame_rate: 50
|
|
||||||
only_mask_loss: True
|
|
||||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
|
||||||
output_size: 512
|
|
||||||
attention_heads: 4
|
|
||||||
linear_units: 1024
|
|
||||||
num_blocks: 3
|
|
||||||
dropout_rate: 0.1
|
|
||||||
positional_dropout_rate: 0.1
|
|
||||||
attention_dropout_rate: 0.1
|
|
||||||
normalize_before: True
|
|
||||||
input_layer: 'linear'
|
|
||||||
pos_enc_layer_type: 'rel_pos_espnet'
|
|
||||||
selfattention_layer_type: 'rel_selfattn'
|
|
||||||
input_size: 512
|
|
||||||
use_cnn_module: False
|
|
||||||
macaron_style: False
|
|
||||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
|
||||||
channels: 80
|
|
||||||
sampling_ratios: [1, 1, 1, 1]
|
|
||||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
|
||||||
in_channels: 240
|
|
||||||
n_spks: 1
|
|
||||||
spk_emb_dim: 80
|
|
||||||
cfm_params: !new:omegaconf.DictConfig
|
|
||||||
content:
|
|
||||||
sigma_min: 1e-06
|
|
||||||
solver: 'euler'
|
|
||||||
t_scheduler: 'cosine'
|
|
||||||
training_cfg_rate: 0.2
|
|
||||||
inference_cfg_rate: 0.7
|
|
||||||
reg_loss_type: 'l1'
|
|
||||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
|
||||||
in_channels: 320
|
|
||||||
out_channels: 80
|
|
||||||
channels: [256, 256]
|
|
||||||
dropout: 0.0
|
|
||||||
attention_head_dim: 64
|
|
||||||
n_blocks: 4
|
|
||||||
num_mid_blocks: 8
|
|
||||||
num_heads: 8
|
|
||||||
act_fn: 'gelu'
|
|
||||||
|
|
||||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
|
||||||
in_channels: 80
|
|
||||||
base_channels: 512
|
|
||||||
nb_harmonics: 8
|
|
||||||
sampling_rate: !ref <sample_rate>
|
|
||||||
nsf_alpha: 0.1
|
|
||||||
nsf_sigma: 0.003
|
|
||||||
nsf_voiced_threshold: 10
|
|
||||||
upsample_rates: [8, 8]
|
|
||||||
upsample_kernel_sizes: [16, 16]
|
|
||||||
istft_params:
|
|
||||||
n_fft: 16
|
|
||||||
hop_len: 4
|
|
||||||
resblock_kernel_sizes: [3, 7, 11]
|
|
||||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
|
||||||
source_resblock_kernel_sizes: [7, 11]
|
|
||||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
|
||||||
lrelu_slope: 0.1
|
|
||||||
audio_limit: 0.99
|
|
||||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
|
||||||
num_class: 1
|
|
||||||
in_channels: 80
|
|
||||||
cond_channels: 512
|
|
||||||
|
|
||||||
# processor functions
|
|
||||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
|
||||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer
|
|
||||||
multilingual: True
|
|
||||||
num_languages: 100
|
|
||||||
language: 'en'
|
|
||||||
task: 'transcribe'
|
|
||||||
allowed_special: 'all'
|
|
||||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
|
||||||
get_tokenizer: !ref <get_tokenizer>
|
|
||||||
allowed_special: !ref <allowed_special>
|
|
||||||
filter: !name:cosyvoice.dataset.processor.filter
|
|
||||||
max_length: 40960
|
|
||||||
min_length: 0
|
|
||||||
token_max_length: 200
|
|
||||||
token_min_length: 1
|
|
||||||
resample: !name:cosyvoice.dataset.processor.resample
|
|
||||||
resample_rate: !ref <sample_rate>
|
|
||||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
|
||||||
n_fft: 1024
|
|
||||||
num_mels: 80
|
|
||||||
sampling_rate: !ref <sample_rate>
|
|
||||||
hop_size: 256
|
|
||||||
win_size: 1024
|
|
||||||
fmin: 0
|
|
||||||
fmax: 8000
|
|
||||||
center: False
|
|
||||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
|
||||||
feat_extractor: !ref <feat_extractor>
|
|
||||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
|
||||||
normalize: True
|
|
||||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
|
||||||
shuffle_size: 1000
|
|
||||||
sort: !name:cosyvoice.dataset.processor.sort
|
|
||||||
sort_size: 500 # sort_size should be less than shuffle_size
|
|
||||||
batch: !name:cosyvoice.dataset.processor.batch
|
|
||||||
batch_type: 'dynamic'
|
|
||||||
max_frames_in_batch: 12000
|
|
||||||
padding: !name:cosyvoice.dataset.processor.padding
|
|
||||||
use_spk_embedding: False # change to True during sft
|
|
||||||
|
|
||||||
# dataset processor pipeline
|
|
||||||
data_pipeline: [
|
|
||||||
!ref <parquet_opener>,
|
|
||||||
!ref <tokenize>,
|
|
||||||
!ref <filter>,
|
|
||||||
!ref <resample>,
|
|
||||||
!ref <compute_fbank>,
|
|
||||||
!ref <parse_embedding>,
|
|
||||||
!ref <shuffle>,
|
|
||||||
!ref <sort>,
|
|
||||||
!ref <batch>,
|
|
||||||
!ref <padding>,
|
|
||||||
]
|
|
||||||
|
|
||||||
# train conf
|
|
||||||
train_conf:
|
|
||||||
optim: adam
|
|
||||||
optim_conf:
|
|
||||||
lr: 0.002 # change to 0.001 if you want to train flow from scratch
|
|
||||||
scheduler: warmuplr
|
|
||||||
scheduler_conf:
|
|
||||||
warmup_steps: 25000
|
|
||||||
max_epoch: 200
|
|
||||||
grad_clip: 5
|
|
||||||
accum_grad: 2
|
|
||||||
log_interval: 100
|
|
||||||
save_per_step: -1
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
../../../cosyvoice
|
|
||||||
@@ -6,6 +6,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
|
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
|
||||||
with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
|
with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
|
||||||
@@ -40,6 +41,7 @@ def main():
|
|||||||
f.write('{} {}\n'.format(k, ' '.join(v)))
|
f.write('{} {}\n'.format(k, ' '.join(v)))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--src_dir',
|
parser.add_argument('--src_dir',
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
|
||||||
export PYTHONIOENCODING=UTF-8
|
|
||||||
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
|
|
||||||
1
examples/magicdata-read/cosyvoice/path.sh
Symbolic link
1
examples/magicdata-read/cosyvoice/path.sh
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../libritts/cosyvoice/path.sh
|
||||||
@@ -27,7 +27,7 @@ fi
|
|||||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
|
||||||
for x in dev test train; do
|
for x in dev test train; do
|
||||||
tools/extract_embedding.py --dir data/$x \
|
../../../tools/extract_embedding.py --dir data/$x \
|
||||||
--onnx_path $pretrained_model_dir/campplus.onnx
|
--onnx_path $pretrained_model_dir/campplus.onnx
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
@@ -35,7 +35,7 @@ fi
|
|||||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
|
||||||
for x in dev test train; do
|
for x in dev test train; do
|
||||||
tools/extract_speech_token.py --dir data/$x \
|
../../../tools/extract_speech_token.py --dir data/$x \
|
||||||
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
--onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
@@ -44,30 +44,13 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
|
||||||
for x in dev test train; do
|
for x in dev test train; do
|
||||||
mkdir -p data/$x/parquet
|
mkdir -p data/$x/parquet
|
||||||
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
../../../tools/make_parquet_list.py --num_utts_per_parquet 1000 \
|
||||||
--num_processes 10 \
|
--num_processes 10 \
|
||||||
--src_dir data/$x \
|
--src_dir data/$x \
|
||||||
--des_dir data/$x/parquet
|
--des_dir data/$x/parquet
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# inference
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
|
||||||
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
|
|
||||||
for mode in sft zero_shot; do
|
|
||||||
python cosyvoice/bin/inference.py --mode $mode \
|
|
||||||
--gpu 0 \
|
|
||||||
--config conf/cosyvoice.yaml \
|
|
||||||
--prompt_data data/test/parquet/data.list \
|
|
||||||
--prompt_utt2data data/test/parquet/utt2data.list \
|
|
||||||
--tts_text `pwd`/tts_text.json \
|
|
||||||
--llm_model $pretrained_model_dir/llm.pt \
|
|
||||||
--flow_model $pretrained_model_dir/flow.pt \
|
|
||||||
--hifigan_model $pretrained_model_dir/hift.pt \
|
|
||||||
--result_dir `pwd`/exp/cosyvoice/test/$mode
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
|
|
||||||
# train llm
|
# train llm
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||||
@@ -83,10 +66,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
fi
|
fi
|
||||||
cp data/train/parquet/data.list data/train.data.list
|
cp data/train/parquet/data.list data/train.data.list
|
||||||
cp data/dev/parquet/data.list data/dev.data.list
|
cp data/dev/parquet/data.list data/dev.data.list
|
||||||
for model in llm; do
|
for model in llm flow hifigan; do
|
||||||
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
||||||
cosyvoice/bin/train.py \
|
../../../cosyvoice/bin/train.py \
|
||||||
--train_engine $train_engine \
|
--train_engine $train_engine \
|
||||||
--config conf/cosyvoice.yaml \
|
--config conf/cosyvoice.yaml \
|
||||||
--train_data data/train.data.list \
|
--train_data data/train.data.list \
|
||||||
@@ -99,7 +82,28 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
--num_workers ${num_workers} \
|
--num_workers ${num_workers} \
|
||||||
--prefetch ${prefetch} \
|
--prefetch ${prefetch} \
|
||||||
--pin_memory \
|
--pin_memory \
|
||||||
|
--use_amp \
|
||||||
--deepspeed_config ./conf/ds_stage2.json \
|
--deepspeed_config ./conf/ds_stage2.json \
|
||||||
--deepspeed.save_states model+optimizer
|
--deepspeed.save_states model+optimizer
|
||||||
done
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# average model
|
||||||
|
average_num=5
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
for model in llm flow hifigan; do
|
||||||
|
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
|
||||||
|
echo "do model average and final checkpoint is $decode_checkpoint"
|
||||||
|
python cosyvoice/bin/average_model.py \
|
||||||
|
--dst_model $decode_checkpoint \
|
||||||
|
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
|
||||||
|
--num ${average_num} \
|
||||||
|
--val_best
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
||||||
|
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
|
||||||
|
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
|
||||||
|
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
|
||||||
fi
|
fi
|
||||||
@@ -1 +0,0 @@
|
|||||||
../../../tools
|
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||||
|
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
|
||||||
conformer==0.3.2
|
conformer==0.3.2
|
||||||
deepspeed==0.14.2; sys_platform == 'linux'
|
deepspeed==0.15.1; sys_platform == 'linux'
|
||||||
diffusers==0.27.2
|
diffusers==0.29.0
|
||||||
|
fastapi==0.115.6
|
||||||
|
fastapi-cli==0.0.4
|
||||||
gdown==5.1.0
|
gdown==5.1.0
|
||||||
gradio==4.32.2
|
gradio==5.4.0
|
||||||
grpcio==1.57.0
|
grpcio==1.57.0
|
||||||
grpcio-tools==1.57.0
|
grpcio-tools==1.57.0
|
||||||
hydra-core==1.3.2
|
hydra-core==1.3.2
|
||||||
@@ -12,20 +15,28 @@ inflect==7.3.1
|
|||||||
librosa==0.10.2
|
librosa==0.10.2
|
||||||
lightning==2.2.4
|
lightning==2.2.4
|
||||||
matplotlib==3.7.5
|
matplotlib==3.7.5
|
||||||
modelscope==1.15.0
|
modelscope==1.20.0
|
||||||
networkx==3.1
|
networkx==3.1
|
||||||
|
numpy==1.26.4
|
||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
onnxruntime-gpu==1.16.0; sys_platform == 'linux'
|
onnx==1.16.0
|
||||||
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
|
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
|
||||||
|
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
|
||||||
openai-whisper==20231117
|
openai-whisper==20231117
|
||||||
protobuf==4.25
|
protobuf==4.25
|
||||||
|
pyarrow==18.1.0
|
||||||
pydantic==2.7.0
|
pydantic==2.7.0
|
||||||
|
pyworld==0.3.4
|
||||||
rich==13.7.1
|
rich==13.7.1
|
||||||
soundfile==0.12.1
|
soundfile==0.12.1
|
||||||
tensorboard==2.14.0
|
tensorboard==2.14.0
|
||||||
torch==2.0.1
|
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
|
||||||
torchaudio==2.0.2
|
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
|
||||||
|
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
|
||||||
|
torch==2.3.1
|
||||||
|
torchaudio==2.3.1
|
||||||
|
transformers==4.51.3
|
||||||
|
x-transformers==2.11.24
|
||||||
|
uvicorn==0.30.0
|
||||||
|
wetext==0.0.4
|
||||||
wget==3.2
|
wget==3.2
|
||||||
fastapi==0.111.0
|
|
||||||
fastapi-cli==0.0.4
|
|
||||||
WeTextProcessing==1.0.3
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ WORKDIR /opt/CosyVoice
|
|||||||
|
|
||||||
RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
|
RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
|
||||||
RUN apt-get update -y
|
RUN apt-get update -y
|
||||||
RUN apt-get -y install git unzip git-lfs
|
RUN apt-get -y install git unzip git-lfs g++
|
||||||
RUN git lfs install
|
RUN git lfs install
|
||||||
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||||
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
|
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
|
||||||
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||||
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
|
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
|
||||||
@@ -1,56 +1,69 @@
|
|||||||
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def saveResponse(path, response):
|
|
||||||
# 以二进制写入模式打开文件
|
|
||||||
with open(path, 'wb') as file:
|
|
||||||
# 将响应的二进制内容写入文件
|
|
||||||
file.write(response.content)
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
api = args.api_base
|
url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode)
|
||||||
if args.mode == 'sft':
|
if args.mode == 'sft':
|
||||||
url = api + "/api/inference/sft"
|
|
||||||
payload={
|
|
||||||
'tts': args.tts_text,
|
|
||||||
'role': args.spk_id
|
|
||||||
}
|
|
||||||
response = requests.request("POST", url, data=payload)
|
|
||||||
saveResponse(args.tts_wav, response)
|
|
||||||
elif args.mode == 'zero_shot':
|
|
||||||
url = api + "/api/inference/zero-shot"
|
|
||||||
payload={
|
|
||||||
'tts': args.tts_text,
|
|
||||||
'prompt': args.prompt_text
|
|
||||||
}
|
|
||||||
files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
|
|
||||||
response = requests.request("POST", url, data=payload, files=files)
|
|
||||||
saveResponse(args.tts_wav, response)
|
|
||||||
elif args.mode == 'cross_lingual':
|
|
||||||
url = api + "/api/inference/cross-lingual"
|
|
||||||
payload={
|
|
||||||
'tts': args.tts_text,
|
|
||||||
}
|
|
||||||
files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
|
|
||||||
response = requests.request("POST", url, data=payload, files=files)
|
|
||||||
saveResponse(args.tts_wav, response)
|
|
||||||
else:
|
|
||||||
url = api + "/api/inference/instruct"
|
|
||||||
payload = {
|
payload = {
|
||||||
'tts': args.tts_text,
|
'tts_text': args.tts_text,
|
||||||
'role': args.spk_id,
|
'spk_id': args.spk_id
|
||||||
'instruct': args.instruct_text
|
|
||||||
}
|
}
|
||||||
response = requests.request("POST", url, data=payload)
|
response = requests.request("GET", url, data=payload, stream=True)
|
||||||
saveResponse(args.tts_wav, response)
|
elif args.mode == 'zero_shot':
|
||||||
logging.info("Response save to {}", args.tts_wav)
|
payload = {
|
||||||
|
'tts_text': args.tts_text,
|
||||||
|
'prompt_text': args.prompt_text
|
||||||
|
}
|
||||||
|
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
|
||||||
|
response = requests.request("GET", url, data=payload, files=files, stream=True)
|
||||||
|
elif args.mode == 'cross_lingual':
|
||||||
|
payload = {
|
||||||
|
'tts_text': args.tts_text,
|
||||||
|
}
|
||||||
|
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
|
||||||
|
response = requests.request("GET", url, data=payload, files=files, stream=True)
|
||||||
|
else:
|
||||||
|
payload = {
|
||||||
|
'tts_text': args.tts_text,
|
||||||
|
'spk_id': args.spk_id,
|
||||||
|
'instruct_text': args.instruct_text
|
||||||
|
}
|
||||||
|
response = requests.request("GET", url, data=payload, stream=True)
|
||||||
|
tts_audio = b''
|
||||||
|
for r in response.iter_content(chunk_size=16000):
|
||||||
|
tts_audio += r
|
||||||
|
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
||||||
|
logging.info('save response to {}'.format(args.tts_wav))
|
||||||
|
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
||||||
|
logging.info('get response')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--api_base',
|
parser.add_argument('--host',
|
||||||
type=str,
|
type=str,
|
||||||
default='http://127.0.0.1:6006')
|
default='0.0.0.0')
|
||||||
|
parser.add_argument('--port',
|
||||||
|
type=int,
|
||||||
|
default='50000')
|
||||||
parser.add_argument('--mode',
|
parser.add_argument('--mode',
|
||||||
default='sft',
|
default='sft',
|
||||||
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
||||||
@@ -66,10 +79,11 @@ if __name__ == "__main__":
|
|||||||
default='希望你以后能够做的比我还好呦。')
|
default='希望你以后能够做的比我还好呦。')
|
||||||
parser.add_argument('--prompt_wav',
|
parser.add_argument('--prompt_wav',
|
||||||
type=str,
|
type=str,
|
||||||
default='../../../zero_shot_prompt.wav')
|
default='../../../asset/zero_shot_prompt.wav')
|
||||||
parser.add_argument('--instruct_text',
|
parser.add_argument('--instruct_text',
|
||||||
type=str,
|
type=str,
|
||||||
default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
|
default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
|
||||||
|
Fights with fervor for justice, but struggles with impulsiveness.')
|
||||||
parser.add_argument('--tts_wav',
|
parser.add_argument('--tts_wav',
|
||||||
type=str,
|
type=str,
|
||||||
default='demo.wav')
|
default='demo.wav')
|
||||||
|
|||||||
@@ -1,119 +1,95 @@
|
|||||||
# Set inference model
|
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||||
# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
|
#
|
||||||
# For development
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# fastapi dev --port 6006 fastapi_server.py
|
# you may not use this file except in compliance with the License.
|
||||||
# For production deployment
|
# You may obtain a copy of the License at
|
||||||
# fastapi run --port 6006 fastapi_server.py
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import io,time
|
import argparse
|
||||||
from fastapi import FastAPI, Response, File, UploadFile, Form
|
import logging
|
||||||
from fastapi.responses import HTMLResponse
|
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||||
from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块
|
from fastapi import FastAPI, UploadFile, Form, File
|
||||||
from contextlib import asynccontextmanager
|
from fastapi.responses import StreamingResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import uvicorn
|
||||||
|
import numpy as np
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
||||||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
from cosyvoice.cli.cosyvoice import AutoModel
|
||||||
from cosyvoice.utils.file_utils import load_wav
|
from cosyvoice.utils.file_utils import load_wav
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
import logging
|
|
||||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
class LaunchFailed(Exception):
|
app = FastAPI()
|
||||||
pass
|
# set cross region allowance
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
|
|
||||||
if model_dir:
|
|
||||||
logging.info("MODEL_DIR is {}", model_dir)
|
|
||||||
app.cosyvoice = CosyVoice(model_dir)
|
|
||||||
# sft usage
|
|
||||||
logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks())
|
|
||||||
else:
|
|
||||||
raise LaunchFailed("MODEL_DIR environment must set")
|
|
||||||
yield
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
|
||||||
|
|
||||||
#设置允许访问的域名
|
|
||||||
origins = ["*"] #"*",即为所有,也可以改为允许的特定ip。
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins, #设置允许的origins来源
|
allow_origins=["*"],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"], # 设置允许跨域的http方法,比如 get、post、put等。
|
allow_methods=["*"],
|
||||||
allow_headers=["*"]) #允许跨域的headers,可以用来鉴别来源等作用。
|
allow_headers=["*"])
|
||||||
|
|
||||||
def buildResponse(output):
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torchaudio.save(buffer, output, 22050, format="wav")
|
|
||||||
buffer.seek(0)
|
|
||||||
return Response(content=buffer.read(-1), media_type="audio/wav")
|
|
||||||
|
|
||||||
@app.post("/api/inference/sft")
|
def generate_data(model_output):
|
||||||
@app.get("/api/inference/sft")
|
for i in model_output:
|
||||||
async def sft(tts: str = Form(), role: str = Form()):
|
tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
||||||
start = time.process_time()
|
yield tts_audio
|
||||||
output = app.cosyvoice.inference_sft(tts, role)
|
|
||||||
end = time.process_time()
|
|
||||||
logging.info("infer time is {} seconds", end-start)
|
|
||||||
return buildResponse(output['tts_speech'])
|
|
||||||
|
|
||||||
@app.post("/api/inference/zero-shot")
|
|
||||||
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
|
|
||||||
start = time.process_time()
|
|
||||||
prompt_speech = load_wav(audio.file, 16000)
|
|
||||||
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
|
||||||
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
|
||||||
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
|
||||||
|
|
||||||
output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
|
@app.get("/inference_sft")
|
||||||
end = time.process_time()
|
@app.post("/inference_sft")
|
||||||
logging.info("infer time is {} seconds", end-start)
|
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
|
||||||
return buildResponse(output['tts_speech'])
|
model_output = cosyvoice.inference_sft(tts_text, spk_id)
|
||||||
|
return StreamingResponse(generate_data(model_output))
|
||||||
|
|
||||||
@app.post("/api/inference/cross-lingual")
|
|
||||||
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
|
|
||||||
start = time.process_time()
|
|
||||||
prompt_speech = load_wav(audio.file, 16000)
|
|
||||||
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
|
||||||
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
|
||||||
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
|
||||||
|
|
||||||
output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
|
@app.get("/inference_zero_shot")
|
||||||
end = time.process_time()
|
@app.post("/inference_zero_shot")
|
||||||
logging.info("infer time is {} seconds", end-start)
|
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
|
||||||
return buildResponse(output['tts_speech'])
|
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
||||||
|
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
|
||||||
|
return StreamingResponse(generate_data(model_output))
|
||||||
|
|
||||||
@app.post("/api/inference/instruct")
|
|
||||||
@app.get("/api/inference/instruct")
|
|
||||||
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
|
|
||||||
start = time.process_time()
|
|
||||||
output = app.cosyvoice.inference_instruct(tts, role, instruct)
|
|
||||||
end = time.process_time()
|
|
||||||
logging.info("infer time is {} seconds", end-start)
|
|
||||||
return buildResponse(output['tts_speech'])
|
|
||||||
|
|
||||||
@app.get("/api/roles")
|
@app.get("/inference_cross_lingual")
|
||||||
async def roles():
|
@app.post("/inference_cross_lingual")
|
||||||
return {"roles": app.cosyvoice.list_avaliable_spks()}
|
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
|
||||||
|
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
||||||
|
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
|
||||||
|
return StreamingResponse(generate_data(model_output))
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
|
||||||
async def root():
|
@app.get("/inference_instruct")
|
||||||
return """
|
@app.post("/inference_instruct")
|
||||||
<!DOCTYPE html>
|
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
|
||||||
<html lang=zh-cn>
|
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
|
||||||
<head>
|
return StreamingResponse(generate_data(model_output))
|
||||||
<meta charset=utf-8>
|
|
||||||
<title>Api information</title>
|
|
||||||
</head>
|
@app.get("/inference_instruct2")
|
||||||
<body>
|
@app.post("/inference_instruct2")
|
||||||
Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
|
async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
|
||||||
</body>
|
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
||||||
</html>
|
model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k)
|
||||||
"""
|
return StreamingResponse(generate_data(model_output))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--port',
|
||||||
|
type=int,
|
||||||
|
default=50000)
|
||||||
|
parser.add_argument('--model_dir',
|
||||||
|
type=str,
|
||||||
|
default='iic/CosyVoice2-0.5B',
|
||||||
|
help='local path or modelscope repo id')
|
||||||
|
args = parser.parse_args()
|
||||||
|
cosyvoice = AutoModel(model_dir=args.model_dir)
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||||
|
|||||||
@@ -61,8 +61,11 @@ def main():
|
|||||||
request.instruct_request.CopyFrom(instruct_request)
|
request.instruct_request.CopyFrom(instruct_request)
|
||||||
|
|
||||||
response = stub.Inference(request)
|
response = stub.Inference(request)
|
||||||
|
tts_audio = b''
|
||||||
|
for r in response:
|
||||||
|
tts_audio += r.tts_audio
|
||||||
|
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
||||||
logging.info('save response to {}'.format(args.tts_wav))
|
logging.info('save response to {}'.format(args.tts_wav))
|
||||||
tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
|
|
||||||
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
torchaudio.save(args.tts_wav, tts_speech, target_sr)
|
||||||
logging.info('get response')
|
logging.info('get response')
|
||||||
|
|
||||||
@@ -90,10 +93,11 @@ if __name__ == "__main__":
|
|||||||
default='希望你以后能够做的比我还好呦。')
|
default='希望你以后能够做的比我还好呦。')
|
||||||
parser.add_argument('--prompt_wav',
|
parser.add_argument('--prompt_wav',
|
||||||
type=str,
|
type=str,
|
||||||
default='../../../zero_shot_prompt.wav')
|
default='../../../asset/zero_shot_prompt.wav')
|
||||||
parser.add_argument('--instruct_text',
|
parser.add_argument('--instruct_text',
|
||||||
type=str,
|
type=str,
|
||||||
default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
|
default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
|
||||||
|
Fights with fervor for justice, but struggles with impulsiveness.')
|
||||||
parser.add_argument('--tts_wav',
|
parser.add_argument('--tts_wav',
|
||||||
type=str,
|
type=str,
|
||||||
default='demo.wav')
|
default='demo.wav')
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ package cosyvoice;
|
|||||||
option go_package = "protos/";
|
option go_package = "protos/";
|
||||||
|
|
||||||
service CosyVoice{
|
service CosyVoice{
|
||||||
rpc Inference(Request) returns (Response) {}
|
rpc Inference(Request) returns (stream Response) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
message Request{
|
message Request{
|
||||||
|
|||||||
@@ -13,9 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
|
||||||
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import argparse
|
import argparse
|
||||||
import cosyvoice_pb2
|
import cosyvoice_pb2
|
||||||
@@ -25,14 +22,18 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|||||||
import grpc
|
import grpc
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
||||||
|
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||||
|
from cosyvoice.cli.cosyvoice import AutoModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
format='%(asctime)s %(levelname)s %(message)s')
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
|
||||||
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.cosyvoice = CosyVoice(args.model_dir)
|
self.cosyvoice = AutoModel(model_dir=args.model_dir)
|
||||||
logging.info('grpc service initialized')
|
logging.info('grpc service initialized')
|
||||||
|
|
||||||
def Inference(self, request, context):
|
def Inference(self, request, context):
|
||||||
@@ -43,7 +44,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
|||||||
logging.info('get zero_shot inference request')
|
logging.info('get zero_shot inference request')
|
||||||
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
||||||
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
||||||
model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
|
model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
|
||||||
|
request.zero_shot_request.prompt_text,
|
||||||
|
prompt_speech_16k)
|
||||||
elif request.HasField('cross_lingual_request'):
|
elif request.HasField('cross_lingual_request'):
|
||||||
logging.info('get cross_lingual inference request')
|
logging.info('get cross_lingual inference request')
|
||||||
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
||||||
@@ -51,12 +54,16 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
|
|||||||
model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
|
model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
|
||||||
else:
|
else:
|
||||||
logging.info('get instruct inference request')
|
logging.info('get instruct inference request')
|
||||||
model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
|
model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
|
||||||
|
request.instruct_request.spk_id,
|
||||||
|
request.instruct_request.instruct_text)
|
||||||
|
|
||||||
logging.info('send inference response')
|
logging.info('send inference response')
|
||||||
response = cosyvoice_pb2.Response()
|
for i in model_output:
|
||||||
response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
response = cosyvoice_pb2.Response()
|
||||||
return response
|
response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
||||||
|
yield response
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
|
grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
|
||||||
@@ -77,7 +84,7 @@ if __name__ == '__main__':
|
|||||||
default=4)
|
default=4)
|
||||||
parser.add_argument('--model_dir',
|
parser.add_argument('--model_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='iic/CosyVoice-300M',
|
default='iic/CosyVoice2-0.5B',
|
||||||
help='local path or modelscope repo id')
|
help='local path or modelscope repo id')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main()
|
main()
|
||||||
|
|||||||
8
runtime/triton_trtllm/Dockerfile.server
Normal file
8
runtime/triton_trtllm/Dockerfile.server
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
|
||||||
|
LABEL maintainer="zhangyuekai@foxmail.com"
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y cmake
|
||||||
|
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
|
||||||
|
COPY ./requirements.txt /workspace/requirements.txt
|
||||||
|
RUN pip install -r /workspace/requirements.txt
|
||||||
|
WORKDIR /workspace
|
||||||
141
runtime/triton_trtllm/README.DIT.md
Normal file
141
runtime/triton_trtllm/README.DIT.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
|
||||||
|
|
||||||
|
Contributed by Yuekai Zhang (NVIDIA).
|
||||||
|
|
||||||
|
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
Launch the service directly with Docker Compose:
|
||||||
|
```sh
|
||||||
|
docker compose -f docker-compose.dit.yml up
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build the Docker Image
|
||||||
|
|
||||||
|
To build the image from scratch:
|
||||||
|
```sh
|
||||||
|
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run a Docker Container
|
||||||
|
```sh
|
||||||
|
your_mount_dir=/mnt:/mnt
|
||||||
|
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Understanding `run_stepaudio2_dit_token2wav.sh`
|
||||||
|
|
||||||
|
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
|
||||||
|
|
||||||
|
You can run a subset of stages with:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
|
||||||
|
```
|
||||||
|
- `<start_stage>`: The stage to start from.
|
||||||
|
- `<stop_stage>`: The stage to stop after.
|
||||||
|
|
||||||
|
**Stages:**
|
||||||
|
|
||||||
|
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
|
||||||
|
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
|
||||||
|
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
|
||||||
|
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
|
||||||
|
- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
|
||||||
|
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
|
||||||
|
- **Stage 5**: Runs the offline TTS inference benchmark test.
|
||||||
|
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
|
||||||
|
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
|
||||||
|
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
|
||||||
|
### Export Models and Launch Server
|
||||||
|
|
||||||
|
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||||
|
```sh
|
||||||
|
# This command runs stages 0, 1, 2, and 3
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 0 3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark with client-server mode
|
||||||
|
|
||||||
|
To benchmark the running Triton server, run stage 4:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 4 4
|
||||||
|
|
||||||
|
# You can customize parameters such as the number of tasks inside the script.
|
||||||
|
```
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||||
|
|
||||||
|
#### Total Request Latency
|
||||||
|
|
||||||
|
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||||
|
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||||
|
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
|
||||||
|
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
|
||||||
|
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
|
||||||
|
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
|
||||||
|
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
|
||||||
|
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
|
||||||
|
|
||||||
|
#### First Chunk Latency
|
||||||
|
|
||||||
|
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
|
||||||
|
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
|
||||||
|
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
|
||||||
|
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
|
||||||
|
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
|
||||||
|
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
|
||||||
|
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
|
||||||
|
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
|
||||||
|
|
||||||
|
### Benchmark with offline inference mode
|
||||||
|
For offline inference mode benchmark, please run stage 5:
|
||||||
|
```sh
|
||||||
|
bash run_stepaudio2_dit_token2wav.sh 5 5
|
||||||
|
```
|
||||||
|
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
|
||||||
|
|
||||||
|
#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
|
||||||
|
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||||
|
|---------|------------|------------------|-----------------------|--|
|
||||||
|
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
|
||||||
|
|
||||||
|
|
||||||
|
### Disaggregated Server
|
||||||
|
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
|
||||||
|
|
||||||
|
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
|
||||||
|
|
||||||
|
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|
||||||
|
|---|---|---|---|---|---|---|
|
||||||
|
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
|
||||||
|
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
|
||||||
|
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
|
||||||
|
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
|
||||||
|
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
|
||||||
|
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
|
||||||
|
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
|
||||||
|
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
|
||||||
|
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
|
||||||
|
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
|
||||||
|
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
|
||||||
|
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
|
||||||
|
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
|
||||||
|
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
|
||||||
|
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
|
||||||
|
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
|
||||||
|
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
|
||||||
|
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
|
||||||
|
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
|
||||||
|
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
|
||||||
|
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
|
||||||
|
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
|
||||||
|
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
|
||||||
|
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
|
||||||
|
|
||||||
|
The `concurrent_task_per_gpu` is calculated as:
|
||||||
|
`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
|
||||||
|
|
||||||
|
### Acknowledgements
|
||||||
|
|
||||||
|
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||||
146
runtime/triton_trtllm/README.md
Normal file
146
runtime/triton_trtllm/README.md
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
|
||||||
|
|
||||||
|
Contributed by Yuekai Zhang (NVIDIA).
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
Launch the service directly with Docker Compose:
|
||||||
|
```sh
|
||||||
|
docker compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build the Docker Image
|
||||||
|
|
||||||
|
To build the image from scratch:
|
||||||
|
```sh
|
||||||
|
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run a Docker Container
|
||||||
|
```sh
|
||||||
|
your_mount_dir=/mnt:/mnt
|
||||||
|
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
|
||||||
|
```
|
||||||
|
|
||||||
|
### Understanding `run.sh`
|
||||||
|
|
||||||
|
The `run.sh` script orchestrates the entire workflow through numbered stages.
|
||||||
|
|
||||||
|
You can run a subset of stages with:
|
||||||
|
```sh
|
||||||
|
bash run.sh <start_stage> <stop_stage> [service_type]
|
||||||
|
```
|
||||||
|
- `<start_stage>`: The stage to start from (0-5).
|
||||||
|
- `<stop_stage>`: The stage to stop after (0-5).
|
||||||
|
|
||||||
|
**Stages:**
|
||||||
|
|
||||||
|
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
|
||||||
|
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
|
||||||
|
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
|
||||||
|
- **Stage 3**: Launches the Triton Inference Server.
|
||||||
|
- **Stage 4**: Runs the single-utterance HTTP client for testing.
|
||||||
|
- **Stage 5**: Runs the gRPC benchmark client.
|
||||||
|
- **Stage 6**: Runs the offline inference benchmark test.
|
||||||
|
|
||||||
|
### Export Models and Launch Server
|
||||||
|
|
||||||
|
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
|
||||||
|
```sh
|
||||||
|
# This command runs stages 0, 1, 2, and 3
|
||||||
|
bash run.sh 0 3
|
||||||
|
```
|
||||||
|
> [!TIP]
|
||||||
|
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
|
||||||
|
|
||||||
|
### Single-Utterance HTTP Client
|
||||||
|
|
||||||
|
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
|
||||||
|
```sh
|
||||||
|
bash run.sh 4 4
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark with client-server mode
|
||||||
|
|
||||||
|
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
|
||||||
|
```sh
|
||||||
|
bash run.sh 5 5 # [streaming|offline]
|
||||||
|
|
||||||
|
# You can also customize parameters such as the number of tasks and the dataset split:
|
||||||
|
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
|
||||||
|
```
|
||||||
|
> [!TIP]
|
||||||
|
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
|
||||||
|
|
||||||
|
### Benchmark with offline inference mode
|
||||||
|
For offline inference mode benchmark, please check the below command:
|
||||||
|
```sh
|
||||||
|
# install FlashCosyVoice for token2wav batching
|
||||||
|
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
|
||||||
|
# cd /workspace/FlashCosyVoice
|
||||||
|
# pip install -e .
|
||||||
|
# cd -
|
||||||
|
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
|
||||||
|
|
||||||
|
bash run.sh 6 6
|
||||||
|
|
||||||
|
# You can also switch to huggingface backend by setting backend=hf
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark Results
|
||||||
|
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
|
||||||
|
|
||||||
|
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
|
||||||
|
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
|
||||||
|
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
|
||||||
|
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
|
||||||
|
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
|
||||||
|
|
||||||
|
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
|
||||||
|
|
||||||
|
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
|
||||||
|
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|
||||||
|
|---|---|---|---|---|---|
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
|
||||||
|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
|
||||||
|
|
||||||
|
**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
|
||||||
|
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|
||||||
|
|---------|------------|------------------|-----------------------|--|
|
||||||
|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
|
||||||
|
| HF | 2 | 30.54 | 35.62 | 0.2064 |
|
||||||
|
| HF | 4 | 18.63 | 23.90 | 0.1421 |
|
||||||
|
| HF | 8 | 11.22 | 16.45 | 0.0947 |
|
||||||
|
| HF | 16 | 8.42 | 13.78 | 0.0821 |
|
||||||
|
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
|
||||||
|
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
|
||||||
|
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
|
||||||
|
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
|
||||||
|
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
|
||||||
|
### OpenAI-Compatible Server
|
||||||
|
|
||||||
|
To launch an OpenAI-compatible API service, run the following commands:
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
|
||||||
|
cd Triton-OpenAI-Speech
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# After the Triton service is running, start the FastAPI bridge:
|
||||||
|
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
|
||||||
|
|
||||||
|
# Test the service with curl:
|
||||||
|
bash test/test_cosyvoice.sh
|
||||||
|
```
|
||||||
|
> [!NOTE]
|
||||||
|
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
|
||||||
|
|
||||||
|
### Acknowledgements
|
||||||
|
|
||||||
|
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
|
||||||
|
|
||||||
922
runtime/triton_trtllm/client_grpc.py
Normal file
922
runtime/triton_trtllm/client_grpc.py
Normal file
@@ -0,0 +1,922 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
# 2023 Nvidia (authors: Yuekai Zhang)
|
||||||
|
# 2023 Recurrent.ai (authors: Songtao Shi)
|
||||||
|
# See LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script supports to load dataset from huggingface and sends it to the server
|
||||||
|
for decoding, in parallel.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
num_task=2
|
||||||
|
|
||||||
|
# For offline F5-TTS
|
||||||
|
python3 client_grpc.py \
|
||||||
|
--server-addr localhost \
|
||||||
|
--model-name f5_tts \
|
||||||
|
--num-tasks $num_task \
|
||||||
|
--huggingface-dataset yuekai/seed_tts \
|
||||||
|
--split-name test_zh \
|
||||||
|
--log-dir ./log_concurrent_tasks_${num_task}
|
||||||
|
|
||||||
|
# For offline Spark-TTS-0.5B
|
||||||
|
python3 client_grpc.py \
|
||||||
|
--server-addr localhost \
|
||||||
|
--model-name spark_tts \
|
||||||
|
--num-tasks $num_task \
|
||||||
|
--huggingface-dataset yuekai/seed_tts \
|
||||||
|
--split-name wenetspeech4tts \
|
||||||
|
--log-dir ./log_concurrent_tasks_${num_task}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import uuid
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import tritonclient
|
||||||
|
import tritonclient.grpc.aio as grpcclient_aio
|
||||||
|
import tritonclient.grpc as grpcclient_sync
|
||||||
|
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
||||||
|
|
||||||
|
|
||||||
|
class UserData:
|
||||||
|
def __init__(self):
|
||||||
|
self._completed_requests = queue.Queue()
|
||||||
|
self._first_chunk_time = None
|
||||||
|
self._second_chunk_time = None
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
|
def record_start_time(self):
|
||||||
|
self._start_time = time.time()
|
||||||
|
|
||||||
|
def get_first_chunk_latency(self):
|
||||||
|
if self._first_chunk_time and self._start_time:
|
||||||
|
return self._first_chunk_time - self._start_time
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_second_chunk_latency(self):
|
||||||
|
if self._first_chunk_time and self._second_chunk_time:
|
||||||
|
return self._second_chunk_time - self._first_chunk_time
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def callback(user_data, result, error):
|
||||||
|
if not error:
|
||||||
|
if user_data._first_chunk_time is None:
|
||||||
|
user_data._first_chunk_time = time.time()
|
||||||
|
elif user_data._second_chunk_time is None:
|
||||||
|
user_data._second_chunk_time = time.time()
|
||||||
|
|
||||||
|
if error:
|
||||||
|
user_data._completed_requests.put(error)
|
||||||
|
else:
|
||||||
|
user_data._completed_requests.put(result)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_callback(user_data_map, result, error):
|
||||||
|
request_id = None
|
||||||
|
if error:
|
||||||
|
print(f"An error occurred in the stream callback: {error}")
|
||||||
|
else:
|
||||||
|
request_id = result.get_response().id
|
||||||
|
|
||||||
|
if request_id:
|
||||||
|
user_data = user_data_map.get(request_id)
|
||||||
|
if user_data:
|
||||||
|
callback(user_data, result, error)
|
||||||
|
else:
|
||||||
|
print(f"Warning: Could not find user_data for request_id {request_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def write_triton_stats(stats, summary_file):
|
||||||
|
with open(summary_file, "w") as summary_f:
|
||||||
|
model_stats = stats["model_stats"]
|
||||||
|
for model_state in model_stats:
|
||||||
|
if "last_inference" not in model_state:
|
||||||
|
continue
|
||||||
|
summary_f.write(f"model name is {model_state['name']} \n")
|
||||||
|
model_inference_stats = model_state["inference_stats"]
|
||||||
|
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
||||||
|
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
||||||
|
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
||||||
|
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
||||||
|
summary_f.write(
|
||||||
|
f"queue time {total_queue_time_s:<5.2f} s, "
|
||||||
|
f"compute infer time {total_infer_time_s:<5.2f} s, "
|
||||||
|
f"compute input time {total_input_time_s:<5.2f} s, "
|
||||||
|
f"compute output time {total_output_time_s:<5.2f} s \n"
|
||||||
|
)
|
||||||
|
model_batch_stats = model_state["batch_stats"]
|
||||||
|
for batch in model_batch_stats:
|
||||||
|
batch_size = int(batch["batch_size"])
|
||||||
|
compute_input = batch["compute_input"]
|
||||||
|
compute_output = batch["compute_output"]
|
||||||
|
compute_infer = batch["compute_infer"]
|
||||||
|
batch_count = int(compute_infer["count"])
|
||||||
|
if batch_count == 0:
|
||||||
|
continue
|
||||||
|
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
||||||
|
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
||||||
|
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
||||||
|
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
||||||
|
summary_f.write(
|
||||||
|
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
|
||||||
|
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
|
||||||
|
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
|
||||||
|
f"{compute_infer_time_ms / batch_count:.2f} ms, "
|
||||||
|
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
|
||||||
|
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
|
||||||
|
)
|
||||||
|
summary_f.write(
|
||||||
|
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
|
||||||
|
)
|
||||||
|
summary_f.write(
|
||||||
|
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def subtract_stats(stats_after, stats_before):
|
||||||
|
"""Subtracts two Triton inference statistics objects."""
|
||||||
|
stats_diff = json.loads(json.dumps(stats_after))
|
||||||
|
|
||||||
|
model_stats_before_map = {
|
||||||
|
s["name"]: {
|
||||||
|
"version": s["version"],
|
||||||
|
"last_inference": s.get("last_inference", 0),
|
||||||
|
"inference_count": s.get("inference_count", 0),
|
||||||
|
"execution_count": s.get("execution_count", 0),
|
||||||
|
"inference_stats": s.get("inference_stats", {}),
|
||||||
|
"batch_stats": s.get("batch_stats", []),
|
||||||
|
}
|
||||||
|
for s in stats_before["model_stats"]
|
||||||
|
}
|
||||||
|
|
||||||
|
for model_stat_after in stats_diff["model_stats"]:
|
||||||
|
model_name = model_stat_after["name"]
|
||||||
|
if model_name in model_stats_before_map:
|
||||||
|
model_stat_before = model_stats_before_map[model_name]
|
||||||
|
|
||||||
|
model_stat_after["inference_count"] = str(
|
||||||
|
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
|
||||||
|
)
|
||||||
|
model_stat_after["execution_count"] = str(
|
||||||
|
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
|
||||||
|
)
|
||||||
|
|
||||||
|
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
|
||||||
|
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
|
||||||
|
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
|
||||||
|
if "ns" in model_stat_after["inference_stats"][key]:
|
||||||
|
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
|
||||||
|
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
|
||||||
|
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
|
||||||
|
if "count" in model_stat_after["inference_stats"][key]:
|
||||||
|
count_after = int(model_stat_after["inference_stats"][key]["count"])
|
||||||
|
count_before = int(model_stat_before["inference_stats"][key]["count"])
|
||||||
|
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
|
||||||
|
|
||||||
|
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
|
||||||
|
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
|
||||||
|
for batch_stat_after in model_stat_after["batch_stats"]:
|
||||||
|
bs = batch_stat_after["batch_size"]
|
||||||
|
if bs in batch_stats_before_map:
|
||||||
|
batch_stat_before = batch_stats_before_map[bs]
|
||||||
|
for key in ["compute_input", "compute_infer", "compute_output"]:
|
||||||
|
if key in batch_stat_after and key in batch_stat_before:
|
||||||
|
count_after = int(batch_stat_after[key]["count"])
|
||||||
|
count_before = int(batch_stat_before[key]["count"])
|
||||||
|
batch_stat_after[key]["count"] = str(count_after - count_before)
|
||||||
|
|
||||||
|
ns_after = int(batch_stat_after[key]["ns"])
|
||||||
|
ns_before = int(batch_stat_before[key]["ns"])
|
||||||
|
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
|
||||||
|
return stats_diff
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-addr",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Address of the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port",
|
||||||
|
type=int,
|
||||||
|
default=8001,
|
||||||
|
help="Grpc port of the triton server, default is 8001",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference-audio",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference-text",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-text",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface-dataset",
|
||||||
|
type=str,
|
||||||
|
default="yuekai/seed_tts",
|
||||||
|
help="dataset name in huggingface dataset hub",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||||
|
help="dataset split name, default is 'test'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--manifest-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="f5_tts",
|
||||||
|
choices=[
|
||||||
|
"f5_tts",
|
||||||
|
"spark_tts",
|
||||||
|
"cosyvoice2",
|
||||||
|
"cosyvoice2_dit"],
|
||||||
|
help="triton model_repo module name to request",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-tasks",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of concurrent tasks for sending",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-interval",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Controls how frequently we print the log.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--compute-wer",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="""True to compute WER.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-dir",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="./tmp",
|
||||||
|
help="log directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
default="offline",
|
||||||
|
choices=["offline", "streaming"],
|
||||||
|
help="Select offline or streaming benchmark mode."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunk-overlap-duration",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-spk2info-cache",
|
||||||
|
type=str,
|
||||||
|
default="False",
|
||||||
|
help="Use spk2info cache for reference audio.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(wav_path, target_sample_rate=16000):
|
||||||
|
assert target_sample_rate == 16000, "hard coding in server"
|
||||||
|
if isinstance(wav_path, dict):
|
||||||
|
waveform = wav_path["array"]
|
||||||
|
sample_rate = wav_path["sampling_rate"]
|
||||||
|
else:
|
||||||
|
waveform, sample_rate = sf.read(wav_path)
|
||||||
|
if sample_rate != target_sample_rate:
|
||||||
|
from scipy.signal import resample
|
||||||
|
|
||||||
|
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||||
|
waveform = resample(waveform, num_samples)
|
||||||
|
return waveform, target_sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_request_input_output(
|
||||||
|
protocol_client,
|
||||||
|
waveform,
|
||||||
|
reference_text,
|
||||||
|
target_text,
|
||||||
|
sample_rate=16000,
|
||||||
|
padding_duration: int = None,
|
||||||
|
use_spk2info_cache: bool = False
|
||||||
|
):
|
||||||
|
"""Prepares inputs for Triton inference (offline or streaming)."""
|
||||||
|
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||||
|
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||||
|
|
||||||
|
if padding_duration:
|
||||||
|
duration = len(waveform) / sample_rate
|
||||||
|
if reference_text:
|
||||||
|
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
||||||
|
else:
|
||||||
|
estimated_target_duration = duration
|
||||||
|
|
||||||
|
required_total_samples = padding_duration * sample_rate * (
|
||||||
|
(int(estimated_target_duration + duration) // padding_duration) + 1
|
||||||
|
)
|
||||||
|
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
||||||
|
samples[0, : len(waveform)] = waveform
|
||||||
|
else:
|
||||||
|
samples = waveform.reshape(1, -1).astype(np.float32)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
||||||
|
protocol_client.InferInput(
|
||||||
|
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
||||||
|
),
|
||||||
|
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
||||||
|
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(samples)
|
||||||
|
inputs[1].set_data_from_numpy(lengths)
|
||||||
|
|
||||||
|
input_data_numpy = np.array([reference_text], dtype=object)
|
||||||
|
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||||
|
inputs[2].set_data_from_numpy(input_data_numpy)
|
||||||
|
|
||||||
|
input_data_numpy = np.array([target_text], dtype=object)
|
||||||
|
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||||
|
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||||
|
|
||||||
|
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||||
|
if use_spk2info_cache:
|
||||||
|
inputs = inputs[-1:]
|
||||||
|
return inputs, outputs
|
||||||
|
|
||||||
|
|
||||||
|
def run_sync_streaming_inference(
|
||||||
|
sync_triton_client: tritonclient.grpc.InferenceServerClient,
|
||||||
|
model_name: str,
|
||||||
|
inputs: list,
|
||||||
|
outputs: list,
|
||||||
|
request_id: str,
|
||||||
|
user_data: UserData,
|
||||||
|
chunk_overlap_duration: float,
|
||||||
|
save_sample_rate: int,
|
||||||
|
audio_save_path: str,
|
||||||
|
):
|
||||||
|
"""Helper function to run the blocking sync streaming call."""
|
||||||
|
start_time_total = time.time()
|
||||||
|
user_data.record_start_time()
|
||||||
|
|
||||||
|
sync_triton_client.async_stream_infer(
|
||||||
|
model_name,
|
||||||
|
inputs,
|
||||||
|
request_id=request_id,
|
||||||
|
outputs=outputs,
|
||||||
|
enable_empty_final_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audios = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = user_data._completed_requests.get(timeout=200)
|
||||||
|
if isinstance(result, InferenceServerException):
|
||||||
|
print(f"Received InferenceServerException: {result}")
|
||||||
|
return None, None, None, None
|
||||||
|
response = result.get_response()
|
||||||
|
final = response.parameters["triton_final_response"].bool_param
|
||||||
|
if final is True:
|
||||||
|
break
|
||||||
|
|
||||||
|
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
||||||
|
if audio_chunk.size > 0:
|
||||||
|
audios.append(audio_chunk)
|
||||||
|
else:
|
||||||
|
print("Warning: received empty audio chunk.")
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
print(f"Timeout waiting for response for request id {request_id}")
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
|
end_time_total = time.time()
|
||||||
|
total_request_latency = end_time_total - start_time_total
|
||||||
|
first_chunk_latency = user_data.get_first_chunk_latency()
|
||||||
|
second_chunk_latency = user_data.get_second_chunk_latency()
|
||||||
|
|
||||||
|
if audios:
|
||||||
|
if model_name == "spark_tts":
|
||||||
|
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
||||||
|
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||||
|
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||||
|
reconstructed_audio = None
|
||||||
|
|
||||||
|
if not audios:
|
||||||
|
print("Warning: No audio chunks received.")
|
||||||
|
reconstructed_audio = np.array([], dtype=np.float32)
|
||||||
|
elif len(audios) == 1:
|
||||||
|
reconstructed_audio = audios[0]
|
||||||
|
else:
|
||||||
|
reconstructed_audio = audios[0][:-cross_fade_samples]
|
||||||
|
for i in range(1, len(audios)):
|
||||||
|
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
||||||
|
audios[i - 1][-cross_fade_samples:] * fade_out)
|
||||||
|
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
||||||
|
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
||||||
|
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
||||||
|
|
||||||
|
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
||||||
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||||
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||||
|
else:
|
||||||
|
print("Warning: No audio chunks received or reconstructed.")
|
||||||
|
actual_duration = 0
|
||||||
|
else:
|
||||||
|
reconstructed_audio = np.concatenate(audios)
|
||||||
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
||||||
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Warning: No audio chunks received.")
|
||||||
|
actual_duration = 0
|
||||||
|
|
||||||
|
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
|
||||||
|
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
manifest_item_list: list,
|
||||||
|
name: str,
|
||||||
|
server_url: str,
|
||||||
|
protocol_client: types.ModuleType,
|
||||||
|
log_interval: int,
|
||||||
|
model_name: str,
|
||||||
|
audio_save_dir: str = "./",
|
||||||
|
save_sample_rate: int = 16000,
|
||||||
|
chunk_overlap_duration: float = 0.1,
|
||||||
|
padding_duration: int = None,
|
||||||
|
use_spk2info_cache: bool = False,
|
||||||
|
):
|
||||||
|
total_duration = 0.0
|
||||||
|
latency_data = []
|
||||||
|
task_id = int(name[5:])
|
||||||
|
sync_triton_client = None
|
||||||
|
user_data_map = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"{name}: Initializing sync client for streaming...")
|
||||||
|
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
|
||||||
|
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
|
||||||
|
|
||||||
|
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
||||||
|
for i, item in enumerate(manifest_item_list):
|
||||||
|
if i % log_interval == 0:
|
||||||
|
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||||
|
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||||
|
|
||||||
|
inputs, outputs = prepare_request_input_output(
|
||||||
|
protocol_client,
|
||||||
|
waveform,
|
||||||
|
reference_text,
|
||||||
|
target_text,
|
||||||
|
sample_rate,
|
||||||
|
padding_duration=padding_duration,
|
||||||
|
use_spk2info_cache=use_spk2info_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
user_data = UserData()
|
||||||
|
user_data_map[request_id] = user_data
|
||||||
|
|
||||||
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||||
|
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
|
||||||
|
run_sync_streaming_inference,
|
||||||
|
sync_triton_client,
|
||||||
|
model_name,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
request_id,
|
||||||
|
user_data,
|
||||||
|
chunk_overlap_duration,
|
||||||
|
save_sample_rate,
|
||||||
|
audio_save_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if total_request_latency is not None:
|
||||||
|
print(
|
||||||
|
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
|
||||||
|
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
|
||||||
|
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
|
||||||
|
)
|
||||||
|
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
|
||||||
|
total_duration += actual_duration
|
||||||
|
else:
|
||||||
|
print(f"{name}: Item {i} failed.")
|
||||||
|
|
||||||
|
del user_data_map[request_id]
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if sync_triton_client:
|
||||||
|
try:
|
||||||
|
print(f"{name}: Closing stream and sync client...")
|
||||||
|
sync_triton_client.stop_stream()
|
||||||
|
sync_triton_client.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{name}: Error closing sync client: {e}")
|
||||||
|
|
||||||
|
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
|
||||||
|
return total_duration, latency_data
|
||||||
|
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
manifest_item_list: list,
|
||||||
|
name: str,
|
||||||
|
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||||
|
protocol_client: types.ModuleType,
|
||||||
|
log_interval: int,
|
||||||
|
model_name: str,
|
||||||
|
padding_duration: int = None,
|
||||||
|
audio_save_dir: str = "./",
|
||||||
|
save_sample_rate: int = 16000,
|
||||||
|
use_spk2info_cache: bool = False,
|
||||||
|
):
|
||||||
|
total_duration = 0.0
|
||||||
|
latency_data = []
|
||||||
|
task_id = int(name[5:])
|
||||||
|
|
||||||
|
for i, item in enumerate(manifest_item_list):
|
||||||
|
if i % log_interval == 0:
|
||||||
|
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||||
|
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||||
|
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||||
|
|
||||||
|
inputs, outputs = prepare_request_input_output(
|
||||||
|
protocol_client,
|
||||||
|
waveform,
|
||||||
|
reference_text,
|
||||||
|
target_text,
|
||||||
|
sample_rate,
|
||||||
|
padding_duration=padding_duration,
|
||||||
|
use_spk2info_cache=use_spk2info_cache
|
||||||
|
)
|
||||||
|
sequence_id = 100000000 + i + task_id * 10
|
||||||
|
start = time.time()
|
||||||
|
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
||||||
|
|
||||||
|
audio = response.as_numpy("waveform").reshape(-1)
|
||||||
|
actual_duration = len(audio) / save_sample_rate
|
||||||
|
|
||||||
|
end = time.time() - start
|
||||||
|
|
||||||
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||||
|
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||||
|
|
||||||
|
latency_data.append((end, actual_duration))
|
||||||
|
total_duration += actual_duration
|
||||||
|
|
||||||
|
return total_duration, latency_data
|
||||||
|
|
||||||
|
|
||||||
|
def load_manifests(manifest_path):
|
||||||
|
with open(manifest_path, "r") as f:
|
||||||
|
manifest_list = []
|
||||||
|
for line in f:
|
||||||
|
assert len(line.strip().split("|")) == 4
|
||||||
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||||
|
utt = Path(utt).stem
|
||||||
|
if not os.path.isabs(prompt_wav):
|
||||||
|
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
||||||
|
manifest_list.append(
|
||||||
|
{
|
||||||
|
"audio_filepath": prompt_wav,
|
||||||
|
"reference_text": prompt_text,
|
||||||
|
"target_text": gt_text,
|
||||||
|
"target_audio_path": utt,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return manifest_list
|
||||||
|
|
||||||
|
|
||||||
|
def split_data(data, k):
|
||||||
|
n = len(data)
|
||||||
|
if n < k:
|
||||||
|
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
||||||
|
k = n
|
||||||
|
|
||||||
|
quotient = n // k
|
||||||
|
remainder = n % k
|
||||||
|
|
||||||
|
result = []
|
||||||
|
start = 0
|
||||||
|
for i in range(k):
|
||||||
|
if i < remainder:
|
||||||
|
end = start + quotient + 1
|
||||||
|
else:
|
||||||
|
end = start + quotient
|
||||||
|
|
||||||
|
result.append(data[start:end])
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
args = get_args()
|
||||||
|
url = f"{args.server_addr}:{args.server_port}"
|
||||||
|
|
||||||
|
triton_client = None
|
||||||
|
protocol_client = None
|
||||||
|
if args.mode == "offline":
|
||||||
|
print("Initializing gRPC client for offline mode...")
|
||||||
|
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||||
|
protocol_client = grpcclient_aio
|
||||||
|
elif args.mode == "streaming":
|
||||||
|
print("Initializing gRPC client for streaming mode...")
|
||||||
|
protocol_client = grpcclient_sync
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid mode: {args.mode}")
|
||||||
|
|
||||||
|
if args.reference_audio:
|
||||||
|
args.num_tasks = 1
|
||||||
|
args.log_interval = 1
|
||||||
|
manifest_item_list = [
|
||||||
|
{
|
||||||
|
"reference_text": args.reference_text,
|
||||||
|
"target_text": args.target_text,
|
||||||
|
"audio_filepath": args.reference_audio,
|
||||||
|
"target_audio_path": "test",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif args.huggingface_dataset:
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
dataset = datasets.load_dataset(
|
||||||
|
args.huggingface_dataset,
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
manifest_item_list = []
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
manifest_item_list.append(
|
||||||
|
{
|
||||||
|
"audio_filepath": dataset[i]["prompt_audio"],
|
||||||
|
"reference_text": dataset[i]["prompt_text"],
|
||||||
|
"target_audio_path": dataset[i]["id"],
|
||||||
|
"target_text": dataset[i]["target_text"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
manifest_item_list = load_manifests(args.manifest_path)
|
||||||
|
|
||||||
|
stats_client = None
|
||||||
|
stats_before = None
|
||||||
|
try:
|
||||||
|
print("Initializing temporary async client for fetching stats...")
|
||||||
|
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
||||||
|
print("Fetching inference statistics before running tasks...")
|
||||||
|
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not retrieve statistics before running tasks: {e}")
|
||||||
|
|
||||||
|
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
||||||
|
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
||||||
|
|
||||||
|
os.makedirs(args.log_dir, exist_ok=True)
|
||||||
|
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
|
||||||
|
tasks = []
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(num_tasks):
|
||||||
|
if args.mode == "offline":
|
||||||
|
task = asyncio.create_task(
|
||||||
|
send(
|
||||||
|
manifest_item_list[i],
|
||||||
|
name=f"task-{i}",
|
||||||
|
triton_client=triton_client,
|
||||||
|
protocol_client=protocol_client,
|
||||||
|
log_interval=args.log_interval,
|
||||||
|
model_name=args.model_name,
|
||||||
|
audio_save_dir=args.log_dir,
|
||||||
|
padding_duration=1,
|
||||||
|
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||||
|
use_spk2info_cache=args.use_spk2info_cache,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif args.mode == "streaming":
|
||||||
|
task = asyncio.create_task(
|
||||||
|
send_streaming(
|
||||||
|
manifest_item_list[i],
|
||||||
|
name=f"task-{i}",
|
||||||
|
server_url=url,
|
||||||
|
protocol_client=protocol_client,
|
||||||
|
log_interval=args.log_interval,
|
||||||
|
model_name=args.model_name,
|
||||||
|
audio_save_dir=args.log_dir,
|
||||||
|
padding_duration=10,
|
||||||
|
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
|
||||||
|
chunk_overlap_duration=args.chunk_overlap_duration,
|
||||||
|
use_spk2info_cache=args.use_spk2info_cache,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
ans_list = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed = end_time - start_time
|
||||||
|
|
||||||
|
total_duration = 0.0
|
||||||
|
latency_data = []
|
||||||
|
for ans in ans_list:
|
||||||
|
if ans:
|
||||||
|
total_duration += ans[0]
|
||||||
|
latency_data.extend(ans[1])
|
||||||
|
else:
|
||||||
|
print("Warning: A task returned None, possibly due to an error.")
|
||||||
|
|
||||||
|
if total_duration == 0:
|
||||||
|
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
|
||||||
|
rtf = float('inf')
|
||||||
|
else:
|
||||||
|
rtf = elapsed / total_duration
|
||||||
|
|
||||||
|
s = f"Mode: {args.mode}\n"
|
||||||
|
s += f"RTF: {rtf:.4f}\n"
|
||||||
|
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||||
|
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||||
|
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
||||||
|
|
||||||
|
if latency_data:
|
||||||
|
if args.mode == "offline":
|
||||||
|
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
||||||
|
if latency_list:
|
||||||
|
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
||||||
|
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"latency_variance: {latency_variance:.2f}\n"
|
||||||
|
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
||||||
|
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
||||||
|
else:
|
||||||
|
s += "No latency data collected for offline mode.\n"
|
||||||
|
|
||||||
|
elif args.mode == "streaming":
|
||||||
|
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
|
||||||
|
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
|
||||||
|
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
|
||||||
|
|
||||||
|
s += "\n--- Total Request Latency ---\n"
|
||||||
|
if total_latency_list:
|
||||||
|
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
|
||||||
|
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
|
||||||
|
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
|
||||||
|
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
|
||||||
|
else:
|
||||||
|
s += "No total request latency data collected.\n"
|
||||||
|
|
||||||
|
s += "\n--- First Chunk Latency ---\n"
|
||||||
|
if first_chunk_latency_list:
|
||||||
|
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
|
||||||
|
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
|
||||||
|
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||||
|
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
||||||
|
else:
|
||||||
|
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
||||||
|
|
||||||
|
s += "\n--- Second Chunk Latency ---\n"
|
||||||
|
if second_chunk_latency_list:
|
||||||
|
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
|
||||||
|
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
||||||
|
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
|
||||||
|
else:
|
||||||
|
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
|
||||||
|
else:
|
||||||
|
s += "No latency data collected.\n"
|
||||||
|
|
||||||
|
print(s)
|
||||||
|
if args.manifest_path:
|
||||||
|
name = Path(args.manifest_path).stem
|
||||||
|
elif args.split_name:
|
||||||
|
name = args.split_name
|
||||||
|
elif args.reference_audio:
|
||||||
|
name = Path(args.reference_audio).stem
|
||||||
|
else:
|
||||||
|
name = "results"
|
||||||
|
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
||||||
|
f.write(s)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if stats_client and stats_before:
|
||||||
|
print("Fetching inference statistics after running tasks...")
|
||||||
|
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
||||||
|
|
||||||
|
print("Calculating statistics difference...")
|
||||||
|
stats = subtract_stats(stats_after, stats_before)
|
||||||
|
|
||||||
|
print("Fetching model config...")
|
||||||
|
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
||||||
|
|
||||||
|
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
||||||
|
|
||||||
|
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
||||||
|
json.dump(metadata, f, indent=4)
|
||||||
|
else:
|
||||||
|
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not retrieve statistics or config: {e}")
|
||||||
|
finally:
|
||||||
|
if stats_client:
|
||||||
|
try:
|
||||||
|
print("Closing temporary async stats client...")
|
||||||
|
await stats_client.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing async stats client: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
async def run_main():
|
||||||
|
try:
|
||||||
|
await main()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred in main: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
asyncio.run(run_main())
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user