mirror of https://github.com/FunAudioLLM/CosyVoice.git
update func export_codec_vllm
@@ -156,7 +156,7 @@ class CosyVoice2(CosyVoice):
         self.model.export_codec_vllm(''.join([model_dir, '/codec_vllm_model']))
         engine_args = EngineArgs(model=''.join([model_dir, '/codec_vllm_model']),
                                  skip_tokenizer_init=True,
-                                 gpu_memory_utilization=0.1)
+                                 gpu_memory_utilization=0.2)
         self.vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
         self.model.vllm_codec_engine = self.vllm_codec_engine
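
This hunk raises the GPU memory fraction reserved for the standalone vLLM engine that serves the exported codec LM from 0.1 to 0.2. For reference, a minimal sketch of the same engine construction, assuming vLLM's public EngineArgs/LLMEngine API; the model directory below is a placeholder, not the commit's actual path handling:

```python
# Sketch only: load the exported codec LM into its own vLLM engine.
# `model_dir` is a placeholder; the real code joins it from CosyVoice2's model_dir.
from vllm import EngineArgs, LLMEngine

model_dir = "pretrained_models/CosyVoice2-0.5B"  # placeholder path

engine_args = EngineArgs(model=model_dir + "/codec_vllm_model",
                         skip_tokenizer_init=True,    # prompts are embeddings/ids, not text
                         gpu_memory_utilization=0.2)  # ~20% of GPU memory for weights + KV cache
vllm_codec_engine = LLMEngine.from_engine_args(engine_args)
```

The larger fraction presumably leaves the codec engine more headroom for its KV cache while the other CosyVoice modules share the same GPU.
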
@@ -347,6 +347,9 @@ class CosyVoice2Model(CosyVoiceModel):
         self.llm.llm.model.to(dtype)
         tmp_vocab_size = self.llm.llm.model.config.vocab_size
         tmp_tie_embedding = self.llm.llm.model.config.tie_word_embeddings
+        del self.llm.llm.model.generation_config.eos_token_id
+        del self.llm.llm.model.config.bos_token_id
+        del self.llm.llm.model.config.eos_token_id
         self.llm.llm.model.config.vocab_size = pad_vocab_size
         self.llm.llm.model.config.tie_word_embeddings = False
         self.llm.llm.model.config.use_bias = True
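
The three added del statements remove the text LM's bos/eos token ids from the Hugging Face config and generation config before the checkpoint is exported for vLLM, alongside the existing overrides of vocab size, embedding tying, and bias. A standalone sketch of the same config surgery, assuming an ordinary transformers Qwen2-style checkpoint; the model id and padded vocabulary size below are illustrative, not the values used by CosyVoice2:

```python
# Sketch only: these attributes live on ordinary transformers config objects
# and can be deleted/overridden before saving the export. Values are illustrative.
from transformers import AutoConfig, GenerationConfig

base = "Qwen/Qwen2-0.5B"  # illustrative base checkpoint
config = AutoConfig.from_pretrained(base)
generation_config = GenerationConfig.from_pretrained(base)

del generation_config.eos_token_id   # drop the text tokenizer's stop id
del config.bos_token_id
del config.eos_token_id

config.vocab_size = 8192             # illustrative padded speech-token vocabulary
config.tie_word_embeddings = False   # keep lm_head separate from the input embeddings
config.use_bias = True
```
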
@@ -343,7 +343,7 @@ class Qwen2LM(TransformerLM):
                                          max_tokens=max_len)
         request_id = uuid.uuid4()
         vllm_codec_engine.add_request(request_id,
-                                      {"prompt_embeds": lm_input.to(torch.bfloat16).to(device)},
+                                      {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(device)},
                                       sampling_params)
         ## generator
         out_token_ids = []
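
The only change in this hunk is dropping the leading batch dimension from lm_input before it is passed to the engine as prompt_embeds, i.e. a [1, seq_len, hidden] tensor becomes [seq_len, hidden]. A tiny sketch of the shape change with illustrative sizes:

```python
# Sketch only: sequence length and hidden size are illustrative.
import torch

lm_input = torch.randn(1, 12, 896)                       # [batch=1, seq_len, hidden]
prompt_embeds = lm_input.squeeze(0).to(torch.bfloat16)   # -> [seq_len, hidden]
print(lm_input.shape, prompt_embeds.shape)               # torch.Size([1, 12, 896]) torch.Size([12, 896])
```
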