Mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 17:39:25 +08:00)
add flow cache inference code
@@ -129,7 +129,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
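
The new keyword defaults to False, so existing call sites keep their old behavior. A minimal usage sketch, assuming the usual CosyVoice2 entry point (the model directory path is illustrative, not part of this commit):

# Hypothetical usage: opt in to cached (streaming) flow inference at
# construction time; all other arguments keep their old defaults.
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=False, load_trt=False, fp16=False,
                       use_flow_cache=True)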
@@ -151,9 +151,9 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
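
The flag also selects which flow checkpoint is loaded: flow.pt for the offline path, flow.cache.pt when caching is enabled. A sketch of the conditional passed to self.model.load(...) above, pulled out as a standalone helper for clarity (the helper name is ours, not the repo's):

# Standalone restatement of the checkpoint-selection expression above.
def flow_checkpoint_path(model_dir: str, use_flow_cache: bool) -> str:
    # flow.cache.pt is the cache-aware checkpoint introduced by this commit
    return '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir)

assert flow_checkpoint_path('model_dir', False) == 'model_dir/flow.pt'
assert flow_checkpoint_path('model_dir', True) == 'model_dir/flow.cache.pt'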
@@ -288,19 +288,20 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool):
+                 fp16: bool,
+                 use_flow_cache: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
+        self.use_flow_cache = use_flow_cache
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        self.token_hop_len = self.flow.encoder.static_chunk_size
-        # flow decoder required_cache_size
-        # TODO num_decoding_left_chunks was not set when the base model was trained; it must be retrained before flow_decoder_required_cache_size can be specified
-        self.flow_decoder_required_cache_size = 999
+        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+        self.token_hop_len = 25
+        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
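
Back-of-the-envelope arithmetic for the new constants. This sketch assumes CosyVoice 2's 24 kHz output rate and the two-mel-frames-per-token ratio implied by the "* 2" prompt trimming in tts() below; neither number appears in this hunk:

# Sanity check of the streaming constants; assumed values are marked as such.
token_hop_len = 25                  # speech tokens per streaming hop (this commit)
mel_frames_per_token = 2            # implied by the "* 2" prompt trimming in tts()
samples_per_mel_frame = 480         # matches source_cache_len = mel_cache_len * 480
sample_rate = 24000                 # assumed CosyVoice 2 output rate

chunk_samples = token_hop_len * mel_frames_per_token * samples_per_mel_frame
print(chunk_samples / sample_rate)  # 1.0 -> each hop yields about 1 s of audio

# -1 means "keep the whole cache" (offline mode); in cache mode the flow
# decoder keeps exactly one hop's worth of KV cache.
flow_decoder_required_cache_size = 1 * token_hop_len  # 25 in cache mode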
@@ -339,7 +340,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return cache
 
     def trim_flow_cache(self, cache):
-        if cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
+        if self.flow_decoder_required_cache_size > 0:
             cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
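
The guard now keys off the configured cache size rather than the current tensor length, so offline mode (-1) never trims. A minimal sketch of the trimming rule on dummy tensors (the five shape values are illustrative; only the last dimension, the one size(4) indexes, matters here):

import torch

required = 25  # flow_decoder_required_cache_size in cache mode
cache = {'decoder_cache': {
    'down_blocks_kv_cache': torch.zeros(1, 2, 3, 4, 40),
    'mid_blocks_kv_cache': torch.zeros(1, 2, 3, 4, 40),
    'up_blocks_kv_cache': torch.zeros(1, 2, 3, 4, 40),
}}
if required > 0:  # -1 (offline mode) leaves the cache untouched
    for key in ('down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache'):
        # keep only the most recent `required` positions along the last axis
        cache['decoder_cache'][key] = cache['decoder_cache'][key][:, :, :, :, -required:]
print(cache['decoder_cache']['down_blocks_kv_cache'].shape)  # torch.Size([1, 2, 3, 4, 25])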
@@ -399,10 +400,10 @@ class CosyVoice2Model(CosyVoiceModel):
                 prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        # NOTE flow model is only trained with static_chunk_size, so we need to trim flow prompt
-        n_chunk = int(flow_prompt_speech_token.size(1) / self.token_hop_len)
-        flow_prompt_speech_token = flow_prompt_speech_token[:, :n_chunk * self.token_hop_len]
-        prompt_speech_feat = prompt_speech_feat[:, :n_chunk * self.token_hop_len * 2]
+        # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
+        if self.use_flow_cache is True:
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
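
In cache mode the prompt is clipped to the cache size instead of being rounded down to a whole number of chunks; the feature sequence is trimmed twice as far because each speech token corresponds to two mel frames. A sketch with dummy tensors (shapes are illustrative):

import torch

flow_decoder_required_cache_size = 25        # 1 * token_hop_len in cache mode
flow_prompt_speech_token = torch.zeros(1, 60, dtype=torch.long)  # 60 prompt tokens
prompt_speech_feat = torch.zeros(1, 120, 80)                     # 2 mel frames per token

# keep only the most recent cache-sized tail of the prompt
flow_prompt_speech_token = flow_prompt_speech_token[:, -flow_decoder_required_cache_size:]
prompt_speech_feat = prompt_speech_feat[:, -flow_decoder_required_cache_size * 2:]
print(flow_prompt_speech_token.shape)  # torch.Size([1, 25])
print(prompt_speech_feat.shape)        # torch.Size([1, 50, 80])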