From 07e477519b8326c0469f2278ec89bad1de98bb8e Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Thu, 23 Jan 2025 10:12:06 +0800 Subject: [PATCH] add llm bistream --- README.md | 10 ++ cosyvoice/cli/cosyvoice.py | 3 +- cosyvoice/cli/frontend.py | 24 ++++- cosyvoice/cli/model.py | 27 +++-- cosyvoice/llm/llm.py | 130 ++++++++++++++++++++----- cosyvoice/utils/common.py | 2 +- docker/Dockerfile | 2 +- examples/libritts/cosyvoice2/cosyvoice | 1 + examples/libritts/cosyvoice2/tools | 1 + tools/extract_speech_token.py | 2 +- 10 files changed, 163 insertions(+), 39 deletions(-) create mode 120000 examples/libritts/cosyvoice2/cosyvoice create mode 120000 examples/libritts/cosyvoice2/tools diff --git a/README.md b/README.md index bd02df7..cbe217a 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,16 @@ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒 # instruct usage for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)): torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + +# bistream usage: you can pass a generator as input, which is useful when the text comes from a streaming text LLM +# NOTE you should still apply some basic sentence-splitting logic, because the LLM cannot handle arbitrarily long text +def text_generator(): + yield '收到好友从远方寄来的生日礼物,' + yield '那份意外的惊喜与深深的祝福' + yield '让我心中充满了甜蜜的快乐,' + yield '笑容如花儿般绽放。' +for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): + torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) ``` **CosyVoice Usage** diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 2da3d0a..e2d62e2 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -13,6 +13,7 @@ # limitations under the License. import os import time +from typing import Generator from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml from modelscope import snapshot_download @@ -76,7 +77,7 @@ class CosyVoice: def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True): prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend) for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)): - if len(i) < 0.5 * len(prompt_text): + if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text): logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text)) model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate) start_time = time.time() diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index ab59a93..6e10f00 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
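The README note above says callers should still do basic sentence splitting before handing a generator to `inference_zero_shot`. Below is a minimal sketch of one way to do that buffering; it is only an illustration, and the `llm_deltas` iterable and the punctuation set are placeholders rather than CosyVoice APIs.

```python
# Minimal sketch (assumption: `llm_deltas` yields text chunks from some streaming text LLM).
# Chunks are buffered until a sentence-ending punctuation mark appears, then one complete
# sentence at a time is yielded, which keeps each bistream segment a reasonable length.
def buffered_text_generator(llm_deltas, enders='。!?;!?;'):
    buf = ''
    for delta in llm_deltas:
        buf += delta
        while True:
            cut = next((i for i, ch in enumerate(buf) if ch in enders), -1)
            if cut == -1:
                break                # no complete sentence yet, keep buffering
            yield buf[:cut + 1]      # emit one complete sentence
            buf = buf[cut + 1:]
    if buf:
        yield buf                    # flush the trailing fragment when the stream ends

# usage mirrors the zero-shot bistream example in the README hunk above:
# cosyvoice.inference_zero_shot(buffered_text_generator(llm_deltas), '希望你以后能够做的比我还好呦。',
#                               prompt_speech_16k, stream=False)
```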
from functools import partial +from typing import Generator import json import onnxruntime import torch @@ -31,6 +32,7 @@ except ImportError: from tn.chinese.normalizer import Normalizer as ZhNormalizer from tn.english.normalizer import Normalizer as EnNormalizer use_ttsfrd = False +from cosyvoice.utils.file_utils import logging from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation @@ -71,10 +73,21 @@ class CosyVoiceFrontEnd: self.inflect_parser = inflect.engine() def _extract_text_token(self, text): - text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) - text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) - text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) - return text_token, text_token_len + if isinstance(text, Generator): + logging.info('get tts_text generator, will return _extract_text_token_generator!') + # NOTE add a dummy text_token_len for compatibility + return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device) + else: + text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) + text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) + text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) + return text_token, text_token_len + + def _extract_text_token_generator(self, text_generator): + for text in text_generator: + text_token, _ = self._extract_text_token(text) + for i in range(text_token.shape[1]): + yield text_token[:, i: i + 1] def _extract_speech_token(self, speech): assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' @@ -106,6 +119,9 @@ class CosyVoiceFrontEnd: return speech_feat, speech_feat_len def text_normalize(self, text, split=True, text_frontend=True): + if isinstance(text, Generator): + logging.info('get tts_text generator, will skip text_normalize!') + return [text] if text_frontend is False: return [text] if split is True else text text = text.strip() diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 9995c44..9ebf8cb 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Generator import torch import numpy as np import threading @@ -99,14 +100,24 @@ class CosyVoiceModel: def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: - for i in self.llm.inference(text=text.to(self.device), - text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), - prompt_text=prompt_text.to(self.device), - prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), - prompt_speech_token=llm_prompt_speech_token.to(self.device), - prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), - embedding=llm_embedding.to(self.device)): - self.tts_speech_token_dict[uuid].append(i) + if isinstance(text, Generator): + assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' 
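For reference, the string-or-generator dispatch that `_extract_text_token` performs can be exercised in isolation. The sketch below uses a toy `ord()`-based encoder as a stand-in for the real tokenizer; only the control flow and the dummy length of 0 mirror the patch.

```python
# Toy illustration of the dispatch in _extract_text_token: a plain string is tokenized
# eagerly into a [1, T] tensor, while a generator is wrapped so that tokens come out one
# [1, 1] slice at a time and the reported length is a dummy 0.
from typing import Generator, Union
import torch

def toy_encode(text: str) -> list:
    return [ord(c) for c in text]                      # stand-in for tokenizer.encode()

def extract_text_token(text: Union[str, Generator]):
    if isinstance(text, Generator):
        return extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32)
    token = torch.tensor([toy_encode(text)], dtype=torch.int32)
    return token, torch.tensor([token.shape[1]], dtype=torch.int32)

def extract_text_token_generator(text_generator: Generator):
    for text in text_generator:
        token, _ = extract_text_token(text)
        for i in range(token.shape[1]):
            yield token[:, i: i + 1]                   # one token per step keeps the input streamable

stream, dummy_len = extract_text_token(t for t in ['你好,', '世界。'])
print(dummy_len.item(), [t.shape for t in stream])     # 0 [torch.Size([1, 1]), ...]
```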
+ for i in self.llm.inference_bistream(text=text, + prompt_text=prompt_text.to(self.device), + prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + embedding=llm_embedding.to(self.device)): + self.tts_speech_token_dict[uuid].append(i) + else: + for i in self.llm.inference(text=text.to(self.device), + text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), + prompt_text=prompt_text.to(self.device), + prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + embedding=llm_embedding.to(self.device)): + self.tts_speech_token_dict[uuid].append(i) self.llm_end_dict[uuid] = True def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index f9bdf3e..78d1f9c 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence from cosyvoice.utils.common import IGNORE_ID from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss from cosyvoice.utils.common import th_accuracy +from cosyvoice.utils.file_utils import logging class TransformerLM(torch.nn.Module): @@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module): sampling: int, ignore_eos: bool = True, ): + num_trials, max_trials = 0, 100 while True: top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) if (not ignore_eos) or (self.speech_token_size not in top_ids): break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) return top_ids @torch.inference_mode() @@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module): return xs, new_cache -class Qwen2LM(torch.nn.Module): +class Qwen2LM(TransformerLM): def __init__( self, llm_input_size: int, @@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module): sampling: Callable, length_normalized_loss: bool = True, lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], ): - super().__init__() + torch.nn.Module.__init__(self) self.llm_input_size = llm_input_size self.llm_output_size = llm_output_size self.speech_token_size = speech_token_size @@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module): # 4. sampling method self.sampling = sampling - - def sampling_ids( - self, - weighted_scores: torch.Tensor, - decoded_tokens: List, - sampling: int, - ignore_eos: bool = True, - ): - num_trials, max_trials = 0, 100 - while True: - top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) - if (not ignore_eos) or (self.speech_token_size not in top_ids): - break - num_trials += 1 - if num_trials > max_trials: - raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) - return top_ids + self.mix_ratio = mix_ratio @torch.inference_mode() def inference( @@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module): text_len += prompt_text_len text = self.llm.model.model.embed_tokens(text) - # 2. 
encode embedding - embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype) - # 3. concat llm_input sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) @@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module): prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) else: prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device) - lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1) + lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1) # 4. cal min/max_length min_len = int((text_len - prompt_text_len) * min_token_text_ratio) @@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module): yield top_ids out_tokens.append(top_ids) lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + @torch.inference_mode() + def inference_bistream( + self, + text: Generator, + prompt_text: torch.Tensor, + prompt_text_len: torch.Tensor, + prompt_speech_token: torch.Tensor, + prompt_speech_token_len: torch.Tensor, + embedding: torch.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + + device = prompt_text.device + # 1. prepare input + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) + lm_input = torch.concat([sos_eos_emb], dim=1) + + # 2. 
iterate text + out_tokens = [] + cache = None + # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5 + text_cache = self.llm.model.model.embed_tokens(prompt_text) + next_fill_index = -1 + for this_text in text: + text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1) + # prompt_speech_token_emb not empty, try append to lm_input + while prompt_speech_token_emb.size(1) != 0: + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]] + logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) + lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1) + text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:] + else: + logging.info('not enough text token to decode, wait for more') + break + # no prompt_speech_token_emb remain, can decode some speech token + if prompt_speech_token_emb.size(1) == 0: + if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1): + logging.info('get fill token, need to append more text token') + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text = text_cache[:, :self.mix_ratio[0]] + logging.info('append {} text token'.format(lm_input_text.size(1))) + lm_input = torch.concat([lm_input, lm_input_text], dim=1) + text_cache = text_cache[:, self.mix_ratio[0]:] + else: + logging.info('not enough text token to decode, wait for more') + continue + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: + top_ids = self.speech_token_size + 2 + next_fill_index += (self.mix_ratio[1] + 1) + else: + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item() + if top_ids == self.speech_token_size + 2: + next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1 + logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index)) + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size + 2: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + # 3. 
final decode + lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1) + logging.info('no more text token, decode until met eos') + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size: + break + else: + raise ValueError('should not get token {}'.format(top_ids)) + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index b356f0c..3e61a8c 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: # attention mask bias # NOTE(Mddct): torch.finfo jit issues # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min - mask = (1.0 - mask) * torch.finfo(dtype).min + mask = (1.0 - mask) * -1.0e+10 return mask diff --git a/docker/Dockerfile b/docker/Dockerfile index 60b101f..59908c0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -34,7 +34,7 @@ RUN conda config --add channels conda-forge && \ # ~conda # ================================================================== -RUN conda create -y -n ${VENV} python=3.8 +RUN conda create -y -n ${VENV} python=3.10 ENV CONDA_DEFAULT_ENV=${VENV} ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH diff --git a/examples/libritts/cosyvoice2/cosyvoice b/examples/libritts/cosyvoice2/cosyvoice new file mode 120000 index 0000000..3903806 --- /dev/null +++ b/examples/libritts/cosyvoice2/cosyvoice @@ -0,0 +1 @@ +../../../cosyvoice \ No newline at end of file diff --git a/examples/libritts/cosyvoice2/tools b/examples/libritts/cosyvoice2/tools new file mode 120000 index 0000000..c92f417 --- /dev/null +++ b/examples/libritts/cosyvoice2/tools @@ -0,0 +1 @@ +../../../tools \ No newline at end of file diff --git a/tools/extract_speech_token.py b/tools/extract_speech_token.py index 26aa296..776b6cf 100755 --- a/tools/extract_speech_token.py +++ b/tools/extract_speech_token.py @@ -24,7 +24,7 @@ import whisper def single_job(utt): - audio, sample_rate = torchaudio.load(utt2wav[utt]) + audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile') if sample_rate != 16000: audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) if audio.shape[1] / 16000 > 30:
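To make the `mix_ratio = [5, 15]` packing in `inference_bistream` easier to follow, here is a schematic, model-free walk-through of the bookkeeping. It only tracks what gets appended to `lm_input`; the speech-token decoding rounds that run between appends (each one ending at a fill token, which is forced every `mix_ratio[1] + 1 = 16` positions) are summarised in the step text. All token counts are made up for illustration.

```python
# Schematic of the bistream packing rule: while prompt speech tokens remain, every 5 cached
# text tokens are paired with up to 15 prompt speech tokens; afterwards each fill token
# emitted by the LM requests 5 more text tokens; the leftover text plus task_id is appended
# at the end and the LM decodes until eos.
def pack_schedule(prompt_text_len, prompt_speech_len, text_chunk_lens, mix_ratio=(5, 15)):
    text_cache = prompt_text_len              # prompt text is pre-filled into the text cache
    speech_left = prompt_speech_len
    steps = []
    for n in text_chunk_lens:                 # one entry per chunk yielded by the text generator
        text_cache += n
        while speech_left > 0 and text_cache >= mix_ratio[0]:
            take = min(mix_ratio[1], speech_left)
            steps.append('append {} text + {} prompt speech'.format(mix_ratio[0], take))
            text_cache -= mix_ratio[0]
            speech_left -= take
        if speech_left == 0 and text_cache >= mix_ratio[0]:
            steps.append('append {} text, decode speech until the next fill token'.format(mix_ratio[0]))
            text_cache -= mix_ratio[0]
    steps.append('append remaining {} text + task_id, decode until eos'.format(text_cache))
    return steps

for step in pack_schedule(prompt_text_len=12, prompt_speech_len=42, text_chunk_lens=[7, 6, 9, 4]):
    print(step)
```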
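The bounded re-sampling guard added to `sampling_ids` can also be looked at on its own. In the stand-alone sketch below a plain multinomial draw stands in for CosyVoice's sampler; only the retry bookkeeping (give up after `max_trials` if eos keeps appearing while `ignore_eos` is set) mirrors the patch.

```python
# Sketch of the retry cap: redraw while the sample still contains eos and eos is not yet
# allowed, but raise instead of looping forever on degenerate inputs.
import torch

def sample_non_eos(weighted_scores, eos_id, ignore_eos=True, top_k=25, max_trials=100):
    num_trials = 0
    while True:
        top_ids = torch.multinomial(weighted_scores.softmax(dim=-1), num_samples=top_k)
        if (not ignore_eos) or (eos_id not in top_ids):
            return top_ids
        num_trials += 1
        if num_trials > max_trials:
            raise RuntimeError('sampling reached max_trials {} and still got eos, check your input'.format(max_trials))

scores = torch.randn(50)                      # dummy logits over a 50-token toy vocabulary
print(sample_non_eos(scores, eos_id=49))
```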