diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 39464ca..7f0211d 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -54,11 +54,13 @@ class CosyVoice: '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: + self.estimator_count = configs.get('estimator_count', 1) self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), - self.fp16) + self.fp16, self.estimator_count) del configs + def list_available_spks(self): spks = list(self.frontend.spk2info.keys()) return spks @@ -178,11 +180,13 @@ class CosyVoice2(CosyVoice): if load_jit: self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32')) if load_trt: + self.estimator_count = configs.get('estimator_count', 1) self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), - self.fp16) + self.fp16, self.estimator_count) del configs + def inference_instruct(self, *args, **kwargs): raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!') diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index c0d25ba..d72816a 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -22,7 +22,8 @@ from contextlib import nullcontext import uuid from cosyvoice.utils.common import fade_in_out from cosyvoice.utils.file_utils import convert_onnx_to_trt - +from cosyvoice.flow.flow_matching import EstimatorWrapper +import queue class CosyVoiceModel: @@ -66,6 +67,12 @@ class CosyVoiceModel: self.flow_cache_dict = {} self.hift_cache_dict = {} + self.stream_context_pool = queue.Queue() + for _ in range(10): + self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) + + self.is_cuda_available = torch.cuda.is_available() + def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) self.llm.to(self.device).eval() @@ -84,7 +91,7 @@ class CosyVoiceModel: flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder - def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16): + def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1): assert torch.cuda.is_available(), 'tensorrt only supports gpu!' if not os.path.exists(flow_decoder_estimator_model): convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16) @@ -96,7 +103,7 @@ class CosyVoiceModel: self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) if self.flow.decoder.estimator_engine is None: raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model)) - self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count) def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): with self.llm_context: @@ -122,13 +129,13 @@ class CosyVoiceModel: def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): tts_mel, flow_cache = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - flow_cache=self.flow_cache_dict[uuid]) + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict[uuid]) self.flow_cache_dict[uuid] = flow_cache # mel overlap fade in out @@ -148,8 +155,8 @@ class CosyVoiceModel: if self.hift_cache_dict[uuid] is not None: tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], - 'source': tts_source[:, :, -self.source_cache_len:], - 'speech': tts_speech[:, -self.source_cache_len:]} + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} tts_speech = tts_speech[:, :-self.source_cache_len] else: if speed != 1.0: @@ -166,63 +173,70 @@ class CosyVoiceModel: flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread - this_uuid = str(uuid.uuid1()) - with self.lock: - self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False - self.hift_cache_dict[this_uuid] = None - self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) - self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) - p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) - p.start() - if stream is True: - token_hop_len = self.token_min_hop_len - while True: - time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ - .unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=False) - yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] - # increase token_hop_len for better speech quality - token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: - break - p.join() - # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=True) - yield {'tts_speech': this_tts_speech.cpu()} - else: - # deal with all tokens - p.join() - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - finalize=True, - speed=speed) - yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict.pop(this_uuid) - self.llm_end_dict.pop(this_uuid) - self.mel_overlap_dict.pop(this_uuid) - self.hift_cache_dict.pop(this_uuid) - self.flow_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() + + stream_context = self.stream_context_pool.get() + with stream_context: + + this_uuid = str(uuid.uuid1()) + with self.lock: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.hift_cache_dict[this_uuid] = None + self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) + self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) + p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + p.start() + if stream is True: + token_hop_len = self.token_min_hop_len + while True: + time.sleep(0.1) + if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ + .unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=False) + yield {'tts_speech': this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] + # increase token_hop_len for better speech quality + token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: + break + p.join() + # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True) + yield {'tts_speech': this_tts_speech.cpu()} + else: + # deal with all tokens + p.join() + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True, + speed=speed) + yield {'tts_speech': this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + self.mel_overlap_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) + + self.synchronize_stream() + self.stream_context_pool.put(stream_context) + torch.cuda.empty_cache() def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread @@ -278,6 +292,10 @@ class CosyVoiceModel: self.hift_cache_dict.pop(this_uuid) torch.cuda.empty_cache() + def synchronize_stream(self): + if self.is_cuda_available: + torch.cuda.current_stream().synchronize() + class CosyVoice2Model(CosyVoiceModel): @@ -314,19 +332,26 @@ class CosyVoice2Model(CosyVoiceModel): self.llm_end_dict = {} self.hift_cache_dict = {} + self.stream_context_pool = queue.Queue() + for _ in range(10): + self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) + + self.is_cuda_available = torch.cuda.is_available() + def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0): + tts_mel, _ = self.flow.inference(token=token.to(self.device), - token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), - prompt_token=prompt_token.to(self.device), - prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), - prompt_feat=prompt_feat.to(self.device), - prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), - embedding=embedding.to(self.device), - finalize=finalize) + token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), + prompt_token=prompt_token.to(self.device), + prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device), + embedding=embedding.to(self.device), + finalize=finalize) tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:] # append hift cache if self.hift_cache_dict[uuid] is not None: @@ -340,8 +365,8 @@ class CosyVoice2Model(CosyVoiceModel): if self.hift_cache_dict[uuid] is not None: tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window) self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], - 'source': tts_source[:, :, -self.source_cache_len:], - 'speech': tts_speech[:, -self.source_cache_len:]} + 'source': tts_source[:, :, -self.source_cache_len:], + 'speech': tts_speech[:, -self.source_cache_len:]} tts_speech = tts_speech[:, :-self.source_cache_len] else: if speed != 1.0: @@ -358,57 +383,64 @@ class CosyVoice2Model(CosyVoiceModel): flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread - this_uuid = str(uuid.uuid1()) - with self.lock: - self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False - self.hift_cache_dict[this_uuid] = None - p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) - p.start() - if stream is True: - token_offset = 0 - while True: - time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - token_offset=token_offset, - finalize=False) - token_offset += self.token_hop_len - yield {'tts_speech': this_tts_speech.cpu()} - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len: - break - p.join() - # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - token_offset=token_offset, - finalize=True) - yield {'tts_speech': this_tts_speech.cpu()} - else: - # deal with all tokens - p.join() - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) - this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - uuid=this_uuid, - token_offset=0, - finalize=True, - speed=speed) - yield {'tts_speech': this_tts_speech.cpu()} - with self.lock: - self.tts_speech_token_dict.pop(this_uuid) - self.llm_end_dict.pop(this_uuid) - torch.cuda.empty_cache() + self.synchronize_stream() + stream_context = self.stream_context_pool.get() + with torch.cuda.stream(stream_context): + + this_uuid = str(uuid.uuid1()) + with self.lock: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.hift_cache_dict[this_uuid] = None + p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) + p.start() + if stream is True: + token_offset = 0 + while True: + time.sleep(0.1) + if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + token_offset=token_offset, + finalize=False) + token_offset += self.token_hop_len + yield {'tts_speech': this_tts_speech.cpu()} + if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len: + break + p.join() + # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + token_offset=token_offset, + finalize=True) + yield {'tts_speech': this_tts_speech.cpu()} + else: + # deal with all tokens + p.join() + this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) + this_tts_speech = self.token2wav(token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + token_offset=0, + finalize=True, + speed=speed) + yield {'tts_speech': this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + + self.synchronize_stream() + self.stream_context_pool.put(stream_context) + torch.cuda.empty_cache() class VllmCosyVoice2Model(CosyVoice2Model): diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 6a60f6d..39643ed 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -15,7 +15,26 @@ import threading import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM +import queue +class EstimatorWrapper: + def __init__(self, estimator_engine, estimator_count=2,): + self.estimators = queue.Queue() + self.estimator_engine = estimator_engine + for _ in range(estimator_count): + estimator = estimator_engine.create_execution_context() + if estimator is not None: + self.estimators.put(estimator) + + if self.estimators.empty(): + raise Exception("No available estimator") + + def acquire_estimator(self): + return self.estimators.get(), self.estimator_engine + + def release_estimator(self, estimator): + self.estimators.put(estimator) + return class ConditionalCFM(BASECFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): @@ -125,22 +144,50 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator.forward(x, mask, mu, t, spks, cond) else: - with self.lock: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + if isinstance(self.estimator, EstimatorWrapper): + estimator, engine = self.estimator.acquire_estimator() + + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + + for idx, data_ptr in enumerate(data_ptrs): + estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr) + # run trt engine - self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) - return x + estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) + + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator) + return x + else: + with self.lock: + self.estimator.set_input_shape('x', (2, 80, x.size(2))) + self.estimator.set_input_shape('mask', (2, 1, x.size(2))) + self.estimator.set_input_shape('mu', (2, 80, x.size(2))) + self.estimator.set_input_shape('t', (2,)) + self.estimator.set_input_shape('spks', (2, 80)) + self.estimator.set_input_shape('cond', (2, 80, x.size(2))) + # run trt engine + self.estimator.execute_v2([x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()]) + return x def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index ac7fe93..cf8ad03 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -61,7 +61,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile()