remove flow_cache

This commit is contained in:
lyuxiang.lx
2025-05-23 12:50:47 +08:00
parent 88f467a8ac
commit 68100c267a
14 changed files with 365 additions and 955 deletions

View File

@@ -33,12 +33,14 @@ class CosyVoiceModel:
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
fp16: bool = False,
trt_concurrent: int = 1):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
self.trt_concurrent = trt_concurrent
if self.fp16 is True:
self.llm.half()
self.flow.half()
@@ -85,23 +87,18 @@ class CosyVoiceModel:
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model):
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
if os.path.getsize(flow_decoder_estimator_model) == 0:
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
if isinstance(self, CosyVoice2Model):
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
else:
self.flow.decoder.estimator = estimator_engine.create_execution_context()
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
@@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel):
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False,
use_flow_cache: bool = False,
trt_concurrent: int = 1):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
# NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
self.flow.encoder.streaming = True
self.flow.decoder.estimator.streaming = True
self.hift = hift
self.fp16 = fp16
self.use_flow_cache = use_flow_cache
self.trt_concurrent = trt_concurrent
if self.fp16 is True:
self.llm.half()
self.flow.half()
# stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
@@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel):
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def init_flow_cache(self):
encoder_cache = {'offset': 0,
'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
'upsample_offset': 0,
'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
decoder_cache = {'offset': 0,
'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
if self.fp16 is True:
for cache in [encoder_cache, decoder_cache]:
for k, v in cache.items():
if isinstance(v, torch.Tensor):
cache[k] = v.half()
cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
return cache
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
cache=self.flow_cache_dict[uuid],
finalize=finalize)
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel):
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.flow_cache_dict[this_uuid] = self.init_flow_cache()
self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
@@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel):
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
# NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
finalize=False)
# NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
token_offset += this_token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=0,
uuid=this_uuid,
finalize=True,
speed=speed)
@@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():