mirror of https://github.com/FunAudioLLM/CosyVoice.git
update export_codec_vllm
@@ -331,6 +331,7 @@ class CosyVoice2Model(CosyVoiceModel):
         pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to

         dtype = torch.bfloat16
+        # lm_head
         new_lm_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
         with torch.no_grad():
             new_lm_head.weight[:vocab_size] = self.llm.llm_decoder.weight
@@ -339,6 +340,8 @@ class CosyVoice2Model(CosyVoiceModel):
             new_lm_head.bias[vocab_size:] = 0
         self.llm.llm.model.lm_head = new_lm_head
         new_codec_embed = nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+        # embed_tokens
+        embed_tokens = self.llm.llm.model.model.embed_tokens
         with torch.no_grad():
             new_codec_embed.weight[:vocab_size] = self.llm.speech_embedding.weight
             new_codec_embed.weight[vocab_size:] = 0
@@ -356,6 +359,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.llm.llm.model.save_pretrained(model_path)
         self.llm.llm.model.config.vocab_size = tmp_vocab_size
         self.llm.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+        self.llm.llm.model.set_input_embeddings(embed_tokens)

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
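
The three hunks above touch the vLLM codec export: the LM head and the speech-token embedding are padded from vocab_size up to pad_vocab_size (the next multiple of pad_to), the padded model is written out with save_pretrained, and the temporarily patched config fields are restored. New in this commit, the original embed_tokens module is captured before export and reinstated afterwards via set_input_embeddings, so the in-memory model is left unchanged by the export. A minimal, self-contained sketch of the pad-and-copy pattern (the pad_lm_head helper and its default pad_to=64 are illustrative assumptions, not code from the repository):

import torch
import torch.nn as nn

def pad_lm_head(old_head: nn.Linear, vocab_size: int, pad_to: int = 64) -> nn.Linear:
    """Hypothetical helper mirroring the diff: widen a vocab-sized head.

    Assumes old_head was created with bias=True, as in the diff above.
    """
    feature_size = old_head.in_features
    # Round the vocabulary up to the next multiple of pad_to, as in the diff.
    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
    new_head = nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
    with torch.no_grad():
        # Copy the trained rows and zero-fill the padding rows.
        new_head.weight[:vocab_size] = old_head.weight
        new_head.weight[vocab_size:] = 0
        new_head.bias[:vocab_size] = old_head.bias
        new_head.bias[vocab_size:] = 0
    return new_head

For example, a head with 6561 outputs padded with pad_to=64 ends up with 6592 rows (103 * 64), since 6561 / 64 rounds up to 103.
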
@@ -353,12 +353,14 @@ class Qwen2LM(TransformerLM):
             if str(request_output.request_id) != str(request_id):
                 continue
             if not request_output.finished:
-                print(f"Partial request output: {request_output}")
+                # print(f"Partial request output: {request_output}")
                 out_token = list(request_output.outputs[0].token_ids)[-1]
                 yield out_token
                 out_token_ids.append(out_token)
             else:
                 break
+            if not vllm_codec_engine.has_unfinished_requests():
+                break

     @torch.inference_mode()
     def inference_bistream(
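
The second hunk tweaks the vLLM streaming loop in Qwen2LM: the partial-output debug print is commented out, and the loop now also breaks once the engine reports no unfinished requests, rather than relying only on the finished flag of the tracked request. A rough standalone sketch of the resulting loop shape, assuming vllm_codec_engine is a synchronous vllm.LLMEngine polled with step() (the stream_tokens wrapper itself is hypothetical):

def stream_tokens(vllm_codec_engine, request_id):
    """Yield new tokens for one request as the engine steps forward."""
    out_token_ids = []
    while True:
        for request_output in vllm_codec_engine.step():
            # Outputs belonging to other concurrent requests are skipped.
            if str(request_output.request_id) != str(request_id):
                continue
            if not request_output.finished:
                # Emit only the newest token of the tracked request.
                out_token = list(request_output.outputs[0].token_ids)[-1]
                yield out_token
                out_token_ids.append(out_token)
            else:
                return
        # The guard added in this commit: stop polling once the engine has
        # drained, so a request that never reports finished cannot make the
        # loop spin forever.
        if not vllm_codec_engine.has_unfinished_requests():
            return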