Mirror of https://github.com/FunAudioLLM/CosyVoice.git
fix(async_cosyvoice): restore the original text-token handling logic
- In the Frontend, restore yielding text tokens one at a time
- In the Model class, remove unneeded logging and assertions to simplify text-token handling
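A minimal sketch of the restored Frontend behavior, using a stand-in tokenizer; `_extract_text_token` in the diff below is assumed to return a [1, T] torch tensor:

    import torch

    # Stand-in for CosyVoiceFrontEnd._extract_text_token: returns a [1, T]
    # token tensor plus a placeholder (shapes assumed from the diff below).
    def extract_text_token(text):
        return torch.arange(len(text)).unsqueeze(0), None

    def extract_text_token_generator(text_generator):
        # Restored behavior: yield one [1, 1] slice per token instead of the
        # whole [1, T] tensor, so downstream stages can stream token by token.
        for text in text_generator:
            text_token, _ = extract_text_token(text)
            for i in range(text_token.shape[1]):
                yield text_token[:, i: i + 1]

    chunks = list(extract_text_token_generator(iter(['hi'])))
    print(len(chunks), chunks[0].shape)  # 2 torch.Size([1, 1])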
@@ -102,9 +102,8 @@ class CosyVoiceFrontEnd:
     def _extract_text_token_generator(self, text_generator):
         for text in text_generator:
             text_token, _ = self._extract_text_token(text)
-            # for i in range(text_token.shape[1]):
-            #     yield text_token[:, i: i + 1]
-            yield text_token
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -149,8 +149,6 @@ class VllmQwen2LM(Qwen2LM):
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # Processing tokens one at a time in a loop is slow; better to batch-append (extend) in the model to cut down on looping
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
 
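The two comment lines removed above suggested yielding the whole chunk and letting the caller extend a list, instead of looping token by token. A minimal sketch of that batched alternative, with a stand-in for the vLLM step output (not this repo's actual API):

    # FakeOutput stands in for a vLLM step output; 6563 is the fill token id
    # used as a stop token in this diff.
    class FakeOutput:
        def __init__(self, token_ids):
            self.token_ids = token_ids

    def generate_chunks(outputs, stop_id=6563):
        for output in outputs:
            # Drop the trailing stop token, mirroring the logic in the diff.
            chunk = output.token_ids[:-1] if output.token_ids[-1] == stop_id else output.token_ids
            yield chunk  # one list per step instead of one token per yield

    outputs = [FakeOutput([11, 12, 13]), FakeOutput([14, 6563])]
    speech_tokens = []
    for chunk in generate_chunks(outputs):
        speech_tokens.extend(chunk)  # batch append; no per-token Python loop
    print(speech_tokens)             # [11, 12, 13, 14]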
@@ -186,18 +184,14 @@ class VllmQwen2LM(Qwen2LM):
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                     prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                 else:
-                    logging.info('not enough text token to decode, wait for more')
                     break
         if len(prompt_speech_token) == 0:
             if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
-                logging.info('get fill token, need to append more text token')
                 if len(text_tokens_cache) >= self.mix_ratio[0]:
                     text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                     prompt_token_ids += text_tokens_temp
-                    logging.info('append {} text token'.format(len(text_tokens_temp)))
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                 else:
-                    logging.info('not enough text token to decode, wait for more')
                     continue
             for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                 last_tokens = output.token_ids
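For context on the slicing in this hunk: text tokens and prompt speech tokens are consumed in `mix_ratio`-sized windows and interleaved into the prompt until one side runs short. A minimal sketch of that windowing with assumed values (`mix_ratio = [5, 15]` matches CosyVoice2 configs but is an assumption here):

    # Sketch of the interleaving implied by the slicing above; plain Python
    # lists of token ids, with assumed window sizes mix_ratio = [5, 15].
    mix_ratio = [5, 15]
    text_tokens_cache = list(range(12))          # pending text token ids
    prompt_speech_token = list(range(100, 140))  # pending prompt speech token ids
    prompt_token_ids = []

    while len(text_tokens_cache) >= mix_ratio[0] and len(prompt_speech_token) >= mix_ratio[1]:
        # One window of text tokens followed by one window of speech tokens.
        prompt_token_ids += text_tokens_cache[:mix_ratio[0]]
        prompt_token_ids += prompt_speech_token[:mix_ratio[1]]
        text_tokens_cache = text_tokens_cache[mix_ratio[0]:]
        prompt_speech_token = prompt_speech_token[mix_ratio[1]:]

    # Two full windows fit; the leftovers wait for more input (the break above).
    print(len(prompt_token_ids), len(text_tokens_cache), len(prompt_speech_token))  # 40 2 10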
@@ -205,19 +199,14 @@ class VllmQwen2LM(Qwen2LM):
                     need_add_tokens = last_tokens[:-1]
                 else:
                     need_add_tokens = last_tokens
-                # Processing tokens one at a time in a loop is slow; better to batch-append (extend) in the model to cut down on looping
-                # yield need_add_tokens
                 for token in need_add_tokens:
                     yield token
                 prompt_token_ids.extend(need_add_tokens)
         prompt_token_ids += text_tokens_cache + [self.task_token_id]
-        logging.info('no more text token, decode until met eos')
         for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
             if output.token_ids[-1] == 6561:
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # Processing tokens one at a time in a loop is slow; better to batch-append (extend) in the model to cut down on looping
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
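One mechanism worth noting in this hunk: `prompt_token_ids.extend(need_add_tokens)` feeds each fill-token round's output back into the prompt, so the final EOS decode (stop id 6561) runs with the full interleaved context. A minimal sketch of that feedback pattern, using a stand-in for `self.llm_inference` (not the real vLLM call):

    # Stand-in inference step: pretend each call produces one new token id
    # followed by the requested stop token (6563 = fill token id in this diff).
    class Output:
        def __init__(self, token_ids):
            self.token_ids = token_ids

    def fake_llm_inference(prompt_token_ids, stop_token_ids):
        yield Output([max(prompt_token_ids) + 1, stop_token_ids[0]])

    prompt_token_ids = [1, 2, 3]
    for output in fake_llm_inference(prompt_token_ids, stop_token_ids=[6563]):
        need_add_tokens = output.token_ids[:-1]   # drop the trailing stop token
        prompt_token_ids.extend(need_add_tokens)  # feed back as context

    print(prompt_token_ids)  # [1, 2, 3, 4]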