Mirror of https://github.com/FunAudioLLM/CosyVoice.git

Commit: optimize flow cache code
@@ -301,7 +301,7 @@ class CosyVoice2Model(CosyVoiceModel):
             self.flow.half()
         # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
+        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
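Note (annotation, not part of the commit): the "cache disabled" sentinel moves from -1 to 0, so the same value can be used directly as a tensor dimension when the decoder KV caches are pre-allocated below. A minimal sketch of the arithmetic, assuming a token_mel_ratio of 2 for illustration (the real value comes from the loaded flow model):

# Sketch only: how the required decoder cache size is derived.
token_hop_len = 25
token_mel_ratio = 2          # assumed value for this example
use_flow_cache = True

# 0 now means "no lookback cache" and is also a valid size-0 tensor dimension.
flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * token_hop_len * token_mel_ratio
print(flow_decoder_required_cache_size)  # 50 mel frames of KV lookback per chunk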
@@ -325,11 +325,11 @@ class CosyVoice2Model(CosyVoiceModel):
                          'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
         decoder_cache = {'offset': 0,
                          'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
-                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
-                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
+                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
-                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                          'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
         if self.fp16 is True:
             for cache in [encoder_cache, decoder_cache]:
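Note (annotation): the decoder KV caches are now pre-allocated with the fixed lookback window instead of length 0, so their shapes no longer change between streaming chunks. A rough sketch of the idea, with init_decoder_cache as a hypothetical helper and shapes copied from the diff:

# Sketch only: pre-allocating a fixed-size decoder cache.
import torch

def init_decoder_cache(required_cache_size, device='cpu'):
    # The 5th dim of every *_kv_cache is the fixed KV lookback window instead of 0.
    return {'offset': 0,
            'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2, device=device),
            'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, required_cache_size, 512, 2, device=device),
            'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2, device=device),
            'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, required_cache_size, 512, 2, device=device),
            'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2, device=device),
            'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, required_cache_size, 512, 2, device=device),
            'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2, device=device)}

cache = init_decoder_cache(required_cache_size=50)
print(cache['down_blocks_kv_cache'].shape)  # torch.Size([10, 1, 4, 2, 50, 512, 2])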
@@ -339,13 +339,6 @@ class CosyVoice2Model(CosyVoiceModel):
         cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
         return cache
 
-    def trim_flow_cache(self, cache):
-        if self.flow_decoder_required_cache_size > 0 and cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
-            cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-            cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-            cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
-        return cache
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
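Note (annotation): trim_flow_cache is dropped because the caches are now kept at the fixed window size inside solve_euler (see the CausalConditionalCFM hunks below) instead of being grown and trimmed after every chunk. For reference, a small sketch of what the removed trimming did, keeping only the trailing frames along dim 4:

# Sketch only: the effect of the removed trim_flow_cache step.
import torch

required = 50
kv = torch.zeros(10, 1, 4, 2, 75, 512, 2)    # a cache that grew past the window
kv_trimmed = kv[:, :, :, :, -required:]      # keep only the last `required` frames
print(kv_trimmed.shape[4])  # 50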
@@ -369,7 +362,6 @@ class CosyVoice2Model(CosyVoiceModel):
                                                  embedding=embedding.to(self.device),
                                                  cache=self.flow_cache_dict[uuid],
                                                  finalize=finalize)
-        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -243,11 +243,6 @@ class CausalConditionalCFM(ConditionalCFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
-        # estimator cache for each step
-        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
-        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
-        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
-
         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
         x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
         mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
@@ -255,6 +250,7 @@ class CausalConditionalCFM(ConditionalCFM):
         t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
         spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
         cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             x_in[:] = x
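Note (annotation): the window length is read once from the pre-allocated cache shape, so no extra flag needs to be threaded through; it is 0 when flow caching is disabled and the per-chunk window otherwise. A tiny sketch, with shapes taken from the diff:

# Sketch only: flow_cache_size falls out of the cache shape.
import torch

cache = {'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 50, 512, 2)}    # streaming
flow_cache_size = cache['down_blocks_kv_cache'].shape[4]                  # 50

cache_off = {'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2)} # cache disabled
print(flow_cache_size, cache_off['down_blocks_kv_cache'].shape[4])        # 50 0 -> 0 skips updates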
@@ -271,13 +267,15 @@ class CausalConditionalCFM(ConditionalCFM):
                 cond_in,
                 cache_step
             )
-            cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-            down_blocks_kv_cache_new[step - 1] = cache_step[1]
-            cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-            mid_blocks_kv_cache_new[step - 1] = cache_step[3]
-            cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-            up_blocks_kv_cache_new[step - 1] = cache_step[5]
-            cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
+            # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
+            if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
+                cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
+                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
+                cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
+                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
+                cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
+                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
+                cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
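Note (annotation): instead of filling per-step scratch buffers and concatenating after the loop, each ODE step now writes the trailing flow_cache_size frames of the freshly returned KV tensors straight into the fixed-size cache, and skips the update for a final chunk shorter than the window (or when caching is disabled). A sketch of that in-place windowed update, with made-up sizes and the estimator call omitted:

# Sketch only: in-place windowed KV cache update for one step.
import torch

flow_cache_size = 50
chunk_len = 100                                            # mel frames in this chunk
kv_cache = torch.zeros(10, 1, 4, 2, flow_cache_size, 512, 2)
new_kv = torch.randn(1, 4, 2, chunk_len, 512, 2)           # stands in for cache_step[1]

step = 1
if flow_cache_size != 0 and chunk_len >= flow_cache_size:
    # keep only the trailing window; shorter final chunks leave the cache untouched
    kv_cache[step - 1] = new_kv[:, :, :, -flow_cache_size:]
print(kv_cache[step - 1].abs().sum() > 0)  # tensor(True): cache refreshed in place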
@@ -285,9 +283,6 @@ class CausalConditionalCFM(ConditionalCFM):
             sol.append(x)
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
-        cache['down_blocks_kv_cache'] = torch.concat([cache['down_blocks_kv_cache'], down_blocks_kv_cache_new], dim=4)
-        cache['mid_blocks_kv_cache'] = torch.concat([cache['mid_blocks_kv_cache'], mid_blocks_kv_cache_new], dim=4)
-        cache['up_blocks_kv_cache'] = torch.concat([cache['up_blocks_kv_cache'], up_blocks_kv_cache_new], dim=4)
         return sol[-1].float(), cache
 
     def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
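Note (annotation): with the cache updated in place at a constant size, the post-loop torch.concat growth is no longer needed, so the KV caches keep one shape and memory layout across chunks, in line with the existing comment about avoiding concat before TensorRT inference. An illustrative comparison of the two behaviours:

# Sketch only: shape stability before vs after (illustrative numbers).
import torch

fixed = torch.zeros(10, 1, 4, 2, 50, 512, 2)                            # new: size never changes
grown = torch.concat([torch.zeros(10, 1, 4, 2, 0, 512, 2),
                      torch.zeros(10, 1, 4, 2, 100, 512, 2)], dim=4)    # old: grew every chunk
print(fixed.shape[4], grown.shape[4])  # 50 100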