From c0f6a474f36643fb061338115e71593776d8345b Mon Sep 17 00:00:00 2001
From: qihua
Date: Sat, 8 Mar 2025 16:03:35 +0800
Subject: [PATCH] fix(async_cosyvoice): restore the original text token handling logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- In the Frontend, restore the original behaviour of yielding text tokens one at a time
- In the Model class (VllmQwen2LM), remove unnecessary logging and stale comments,
  simplifying the text token handling flow

---
 cosyvoice/cli/frontend.py |  5 ++---
 cosyvoice/llm/llm_vllm.py | 11 -----------
 2 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py
index 5aa2d34..834f0b0 100644
--- a/cosyvoice/cli/frontend.py
+++ b/cosyvoice/cli/frontend.py
@@ -102,9 +102,8 @@ class CosyVoiceFrontEnd:
     def _extract_text_token_generator(self, text_generator):
         for text in text_generator:
             text_token, _ = self._extract_text_token(text)
-            # for i in range(text_token.shape[1]):
-            #     yield text_token[:, i: i + 1]
-            yield text_token
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
diff --git a/cosyvoice/llm/llm_vllm.py b/cosyvoice/llm/llm_vllm.py
index 839bf88..a864a04 100644
--- a/cosyvoice/llm/llm_vllm.py
+++ b/cosyvoice/llm/llm_vllm.py
@@ -149,8 +149,6 @@ class VllmQwen2LM(Qwen2LM):
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # Looping over single tokens is slow; batch (extend) in the model instead to reduce looping
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
 
@@ -186,18 +184,14 @@ class VllmQwen2LM(Qwen2LM):
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                     prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                 else:
-                    logging.info('not enough text token to decode, wait for more')
                     break
             if len(prompt_speech_token) == 0:
                 if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
-                    logging.info('get fill token, need to append more text token')
                     if len(text_tokens_cache) >= self.mix_ratio[0]:
                         text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                         prompt_token_ids += text_tokens_temp
-                        logging.info('append {} text token'.format(len(text_tokens_temp)))
                         text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                     else:
-                        logging.info('not enough text token to decode, wait for more')
                         continue
             for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                 last_tokens = output.token_ids
@@ -205,19 +199,14 @@ class VllmQwen2LM(Qwen2LM):
                 need_add_tokens = last_tokens[:-1]
             else:
                 need_add_tokens = last_tokens
-            # Looping over single tokens is slow; batch (extend) in the model instead to reduce looping
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
             prompt_token_ids.extend(need_add_tokens)
         prompt_token_ids += text_tokens_cache + [self.task_token_id]
-        logging.info('no more text token, decode until met eos')
        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
             if output.token_ids[-1] == 6561:
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # Looping over single tokens is slow; batch (extend) in the model instead to reduce looping
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
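
Note (not part of the patch): a minimal sketch of the restored per-token behaviour of
_extract_text_token_generator. The _extract_text_token stand-in below is hypothetical and
only mimics the real tokenizer's (1, T) output shape; the point is that the generator now
yields (1, 1) slices that downstream bistream decoding can consume incrementally.

import torch

def _extract_text_token(text):
    # Hypothetical stand-in: returns a (1, T) tensor of token ids plus its length.
    token_ids = torch.tensor([[ord(ch) for ch in text]], dtype=torch.long)
    return token_ids, token_ids.shape[1]

def _extract_text_token_generator(text_generator):
    # Restored behaviour from the patch: yield one (1, 1) token slice at a time
    # instead of the whole (1, T) tensor, so text and speech tokens can be
    # interleaved as text arrives.
    for text in text_generator:
        text_token, _ = _extract_text_token(text)
        for i in range(text_token.shape[1]):
            yield text_token[:, i: i + 1]

for tok in _extract_text_token_generator(iter(["hi", "ok"])):
    print(tok.shape)  # torch.Size([1, 1]) for every yielded slice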