This commit is contained in:
lyuxiang.lx
2025-04-15 15:00:29 +08:00
parent 36aec2c0f7
commit c07cd3d730

View File

@@ -270,11 +270,11 @@ class CausalConditionalCFM(ConditionalCFM):
# NOTE if smaller than flow_cache_size, means last chunk, no need to cache
if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:,:,:,-flow_cache_size:]
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:,:,:,-flow_cache_size:]
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:,:,:,-flow_cache_size:]
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)