From 3e12bb86bdf9855dd44a14a1bcb83e07f0a1e00f Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Mon, 26 May 2025 18:03:15 +0800
Subject: [PATCH] fix trt wrapper bug

---
 cosyvoice/cli/model.py                            | 11 ++++++++++-
 cosyvoice/utils/mask.py                           |  3 +++
 examples/libritts/cosyvoice2/conf/cosyvoice2.yaml |  2 +-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index aa110b1..811b2cb 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -59,6 +59,9 @@ class CosyVoiceModel:
             self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        for _ in range(trt_concurrent):
+            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -66,6 +69,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.trt_context_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -176,11 +180,13 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -234,6 +240,8 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
+            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
@@ -324,10 +332,11 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
+        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
+            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
diff --git a/cosyvoice/utils/mask.py b/cosyvoice/utils/mask.py
index c966cc9..5d3dfd6 100644
--- a/cosyvoice/utils/mask.py
+++ b/cosyvoice/utils/mask.py
@@ -230,6 +230,9 @@ def add_optional_chunk_mask(xs: torch.Tensor,
     else:
         chunk_masks = masks
     assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
     return chunk_masks
 
 
diff --git a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
index d6bdeb6..84d1bd5 100644
--- a/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+++ b/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
@@ -15,7 +15,7 @@ token_mel_ratio: 2
 
 # stream related params
 chunk_size: 25 # streaming inference chunk size, in token
-num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
+num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
 
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
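
A minimal standalone sketch of the acquire/release pattern the model.py change introduces: a fixed-size pool of per-session contexts (CUDA stream contexts when available, nullcontext otherwise), sized by trt_concurrent, taken from a queue.Queue when a request starts and returned when it finishes so at most trt_concurrent requests touch the TRT engine at once. The infer() helper and the thread count below are hypothetical, for illustration only; this is not part of the patch.

import queue
import threading
import uuid
from contextlib import nullcontext

import torch

trt_concurrent = 2  # mirrors the pool size used in the patch
trt_context_pool = queue.Queue(maxsize=trt_concurrent)
for _ in range(trt_concurrent):
    # one CUDA stream context per slot; fall back to nullcontext on CPU
    trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext())


def infer(this_uuid):
    # block until a context is free, bounding concurrent TRT/flow inference
    ctx = trt_context_pool.get()
    try:
        with ctx:
            pass  # per-request flow/TRT inference would run here
    finally:
        # always hand the context back, even if inference raises
        trt_context_pool.put(ctx)


threads = [threading.Thread(target=infer, args=(str(uuid.uuid1()),)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()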