fix export_jit.py
@@ -24,7 +24,6 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import AutoModel
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 
 
@@ -60,7 +59,7 @@ def main():
 
     model = AutoModel(model_dir=args.model_dir)
 
-    if isinstance(model.model, CosyVoiceModel):
+    if model.__class__.__name__ == 'CosyVoice':
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
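The change above (and its elif counterpart below) stops dispatching on the type of the inner model.model and instead compares the class name of the wrapper returned by AutoModel. A plausible reason is that an isinstance check against a base model class also matches its subclasses, so the first branch can swallow newer model variants. The following is a minimal, self-contained sketch of that pitfall using placeholder classes; it is not the CosyVoice code itself.

# Hypothetical stand-ins: BaseModel plays the role of a v1 model class and
# V2Model a v2 class that subclasses it (an assumption for illustration).
class BaseModel:
    pass

class V2Model(BaseModel):
    pass

model = V2Model()

# isinstance() dispatch: the base-class branch also matches the subclass,
# so a v2 model would be routed down the v1 export path.
if isinstance(model, BaseModel):
    picked_by_isinstance = 'v1 path'   # taken even for V2Model
elif isinstance(model, V2Model):
    picked_by_isinstance = 'v2 path'   # never reached

# Exact class-name dispatch: each class selects its own branch.
if model.__class__.__name__ == 'BaseModel':
    picked_by_name = 'v1 path'
elif model.__class__.__name__ == 'V2Model':
    picked_by_name = 'v2 path'         # taken for V2Model

print(picked_by_isinstance, picked_by_name)  # -> v1 path v2 path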
@@ -84,7 +83,7 @@ def main():
         script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
-    elif isinstance(model.model, CosyVoice2Model):
+    elif model.__class__.__name__ == 'CosyVoice2':
         # 1. export flow encoder
         flow_encoder = model.model.flow.encoder
         script = get_optimized_script(flow_encoder)
@@ -114,7 +114,7 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
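This hunk applies the same idea inside the CosyVoice wrapper itself: the guard now inspects the wrapper's own class name rather than the type of the wrapped model, so a CosyVoice2 instance fails the assert with the intended message. A hedged usage sketch follows, assuming an instruct-capable v1 checkpoint directory and speaker id; both are placeholders, not values taken from this commit.

# Hypothetical usage of the guarded method; model_dir and spk_id are
# placeholders and may not match a checkpoint on your machine.
from cosyvoice.cli.cosyvoice import AutoModel

model = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')

# Passes the assert only when AutoModel resolves to the v1 `CosyVoice` class.
for out in model.inference_instruct('Hello there.', '中文女',
                                    'Please speak in a calm, gentle tone.',
                                    stream=False):
    audio_chunk = out  # each yielded item carries a chunk of synthesized speech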
@@ -100,7 +100,7 @@ class CosyVoiceModel:
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
+                assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
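Here the assert is inverted rather than widened: instead of isinstance(self, CosyVoice2Model), anything whose class name is not the v1 'CosyVoiceModel' is accepted, which (per the new message) is meant to cover both CosyVoice2 and CosyVoice3 without importing every newer model class. A minimal sketch of that pattern with stand-in classes, not the actual CosyVoice types:

# Stand-in classes for illustration only (hypothetical names).
class CosyVoiceModelStub:
    pass

class CosyVoice2ModelStub(CosyVoiceModelStub):
    pass

class CosyVoice3ModelStub(CosyVoiceModelStub):
    pass

def supports_streaming_text(model) -> bool:
    # Exclude the v1 class by name; any newer model class passes without
    # this helper needing to import or even know about it.
    return model.__class__.__name__ != 'CosyVoiceModelStub'

assert not supports_streaming_text(CosyVoiceModelStub())
assert supports_streaming_text(CosyVoice2ModelStub())
assert supports_streaming_text(CosyVoice3ModelStub())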