mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update
This commit is contained in:
@@ -23,7 +23,7 @@ import torch
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
||||
|
||||
|
||||
def get_args():
|
||||
@@ -37,6 +37,15 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
def get_optimized_script(model, preserved_attrs=[]):
|
||||
script = torch.jit.script(model)
|
||||
if preserved_attrs != []:
|
||||
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
||||
else:
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
return script
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
@@ -46,28 +55,35 @@ def main():
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
|
||||
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
||||
try:
|
||||
model = CosyVoice(args.model_dir)
|
||||
except:
|
||||
try:
|
||||
model = CosyVoice2(args.model_dir)
|
||||
except:
|
||||
raise TypeError('no valid model_type!')
|
||||
|
||||
# 1. export llm text_encoder
|
||||
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
||||
script = torch.jit.script(llm_text_encoder)
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||
if not isinstance(model, CosyVoice2):
|
||||
# 1. export llm text_encoder
|
||||
llm_text_encoder = model.model.llm.text_encoder
|
||||
script = get_optimized_script(llm_text_encoder)
|
||||
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(llm_text_encoder.half())
|
||||
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||
|
||||
# 2. export llm llm
|
||||
llm_llm = cosyvoice.model.llm.llm.half()
|
||||
script = torch.jit.script(llm_llm)
|
||||
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||
# 2. export llm llm
|
||||
llm_llm = model.model.llm.llm
|
||||
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
||||
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
||||
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||
|
||||
# 3. export flow encoder
|
||||
flow_encoder = cosyvoice.model.flow.encoder
|
||||
script = torch.jit.script(flow_encoder)
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user