add flow cache inference code

This commit is contained in:
lyuxiang.lx
2025-04-07 21:23:09 +08:00
parent a69b7e275d
commit 39ffc50dec
4 changed files with 19 additions and 18 deletions

View File

@@ -129,7 +129,7 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
@@ -151,9 +151,9 @@ class CosyVoice2(CosyVoice):
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))