mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 01:49:25 +08:00
fix trt wrapper bug
This commit is contained in:
@@ -59,6 +59,9 @@ class CosyVoiceModel:
|
||||
self.stream_scale_factor = 1
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
for _ in range(trt_concurrent):
|
||||
self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
@@ -66,6 +69,7 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict = {}
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.trt_context_dict = {}
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
@@ -176,11 +180,13 @@ class CosyVoiceModel:
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
this_trt_context = self.trt_context_pool.get()
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||
self.trt_context_dict[this_uuid] = this_trt_context
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
@@ -234,6 +240,8 @@ class CosyVoiceModel:
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
|
||||
self.trt_context_dict.pop(this_uuid)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
@@ -324,10 +332,11 @@ class CosyVoice2Model(CosyVoiceModel):
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
this_trt_context = self.trt_context_pool.get()
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
|
||||
self.trt_context_dict[this_uuid] = this_trt_context
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
|
||||
@@ -230,6 +230,9 @@ def add_optional_chunk_mask(xs: torch.Tensor,
|
||||
else:
|
||||
chunk_masks = masks
|
||||
assert chunk_masks.dtype == torch.bool
|
||||
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
||||
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
||||
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
||||
return chunk_masks
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user