diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index fb1cd7f..71351a2 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -140,7 +140,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1): + def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 19bedd3..6ebbe52 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -59,9 +59,6 @@ class CosyVoiceModel: self.stream_scale_factor = 1 assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf' self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() - self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) - for _ in range(trt_concurrent): - self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} @@ -69,7 +66,6 @@ class CosyVoiceModel: self.mel_overlap_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} - self.trt_context_dict = {} def load(self, llm_model, flow_model, hift_model): self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True) @@ -98,7 +94,7 @@ class CosyVoiceModel: with open(flow_decoder_estimator_model, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model) - self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent) + self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device) def get_trt_kwargs(self): min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] @@ -125,7 +121,8 @@ class CosyVoiceModel: prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), prompt_speech_token=llm_prompt_speech_token.to(self.device), prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), - embedding=llm_embedding.to(self.device)): + embedding=llm_embedding.to(self.device), + uuid=uuid): self.tts_speech_token_dict[uuid].append(i) self.llm_end_dict[uuid] = True @@ -180,13 +177,11 @@ class CosyVoiceModel: prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) - this_trt_context = self.trt_context_pool.get() with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) - self.trt_context_dict[this_uuid] = this_trt_context if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -240,8 +235,6 @@ class CosyVoiceModel: self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) - self.trt_context_pool.put(self.trt_context_dict[this_uuid]) - self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() @@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel): self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() - self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) - for _ in range(trt_concurrent): - self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()) self.lock = threading.Lock() # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} self.hift_cache_dict = {} - self.trt_context_dict = {} def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder + def load_vllm(self, model_dir): + export_cosyvoice2_vllm(self.llm, model_dir, self.device) + from vllm import EngineArgs, LLMEngine + engine_args = EngineArgs(model=model_dir, + skip_tokenizer_init=True, + enable_prompt_embeds=True, + gpu_memory_utilization=0.2) + self.llm.vllm = LLMEngine.from_engine_args(engine_args) + del self.llm.llm.model.model.layers + def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): - with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]: + with torch.cuda.amp.autocast(self.fp16): tts_mel, _ = self.flow.inference(token=token.to(self.device), token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device), prompt_token=prompt_token.to(self.device), @@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel): prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs): # this_uuid is used to track variables related to this inference thread this_uuid = str(uuid.uuid1()) - this_trt_context = self.trt_context_pool.get() with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None - self.trt_context_dict[this_uuid] = this_trt_context if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) else: @@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) - self.trt_context_pool.put(self.trt_context_dict[this_uuid]) - self.trt_context_dict.pop(this_uuid) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 704ced3..9f7d0be 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -16,6 +16,7 @@ import threading import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM +from cosyvoice.utils.common import set_all_random_seed class ConditionalCFM(BASECFM): @@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM): in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) # Just change the architecture of the estimator here self.estimator = estimator - self.lock = threading.Lock() @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)): @@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM): if isinstance(self.estimator, torch.nn.Module): return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) else: - estimator, trt_engine = self.estimator.acquire_estimator() - estimator.set_input_shape('x', (2, 80, x.size(2))) - estimator.set_input_shape('mask', (2, 1, x.size(2))) - estimator.set_input_shape('mu', (2, 80, x.size(2))) - estimator.set_input_shape('t', (2,)) - estimator.set_input_shape('spks', (2, 80)) - estimator.set_input_shape('cond', (2, 80, x.size(2))) - data_ptrs = [x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()] - for i, j in enumerate(data_ptrs): - estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) - # run trt engine - assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() - self.estimator.release_estimator(estimator) + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + with stream: + estimator.set_input_shape('x', (2, 80, x.size(2))) + estimator.set_input_shape('mask', (2, 1, x.size(2))) + estimator.set_input_shape('mu', (2, 80, x.size(2))) + estimator.set_input_shape('t', (2,)) + estimator.set_input_shape('spks', (2, 80)) + estimator.set_input_shape('cond', (2, 80, x.size(2))) + data_ptrs = [x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr()] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + # run trt engine + assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True + torch.cuda.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) return x def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): @@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM): class CausalConditionalCFM(ConditionalCFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) + set_all_random_seed(0) self.rand_noise = torch.randn([1, 80, 50 * 300]) @torch.inference_mode() diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index 670ae69..c5899ac 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import queue import random +import time +import threading from typing import Dict, Optional, Callable, List, Generator import torch from torch import nn @@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module): sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, + uuid: str = '', ) -> Generator[torch.Tensor, None, None]: device = text.device text = torch.concat([prompt_text, text], dim=1) @@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM): self.llm_input_size = llm_input_size self.llm_output_size = llm_output_size self.speech_token_size = speech_token_size - # 2. build speech token language model related modules self.sos_eos = 0 self.task_id = 1 @@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM): # 4. sampling method self.sampling = sampling self.mix_ratio = mix_ratio + + # 5. vllm related + self.stop_token_ids = [speech_token_size + i for i in range(3)] + self.vllm_output_queue = {} + self.lock = threading.Lock() def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len): lm_target, lm_input = [], [] @@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM): sampling: int = 25, max_token_text_ratio: float = 20, min_token_text_ratio: float = 2, + uuid: str = '', ) -> Generator[torch.Tensor, None, None]: device = text.device text = torch.concat([prompt_text, text], dim=1) @@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM): max_len = int((text_len - prompt_text_len) * max_token_text_ratio) # 5. step by step decode - out_tokens = [] - cache = None - for i in range(max_len): - y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), - cache=cache) - logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) - top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() - if top_ids == self.speech_token_size: - break - if top_ids > self.speech_token_size: - continue - # in stream mode, yield token one by one - yield top_ids - out_tokens.append(top_ids) - lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid): + yield token + + @torch.inference_mode() + def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): + if hasattr(self, 'vllm'): + from vllm import SamplingParams, RequestOutput + sampling_params = SamplingParams(top_k=sampling, + stop_token_ids=self.stop_token_ids, + min_tokens=min_len, + max_tokens=max_len) + with self.lock: + self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params) + self.vllm_output_queue[uuid] = queue.Queue() + out_tokens = [] + while True: + with self.lock: + if self.vllm_output_queue[uuid].empty() is True: + request_outputs: List[RequestOutput] = self.vllm.step() + for request_output in request_outputs: + top_ids = list(request_output.outputs[0].token_ids)[-1] + self.vllm_output_queue[request_output.request_id].put(top_ids) + if self.vllm_output_queue[uuid].empty() is False: + top_ids = self.vllm_output_queue[uuid].get() + if top_ids in self.stop_token_ids: + break + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + time.sleep(0.001) + with self.lock: + self.vllm_output_queue.pop(uuid) + else: + out_tokens = [] + cache = None + for i in range(max_len): + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() + if top_ids == self.speech_token_size: + break + if top_ids > self.speech_token_size: + continue + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) @torch.inference_mode() def inference_bistream( diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 088ca69..6f5a3dd 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: class TrtContextWrapper: - def __init__(self, trt_engine, trt_concurrent=1): - self.trt_context_pool = queue.Queue() + def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) self.trt_engine = trt_engine for _ in range(trt_concurrent): trt_context = trt_engine.create_execution_context() + trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) - self.trt_context_pool.put(trt_context) + self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' def acquire_estimator(self): return self.trt_context_pool.get(), self.trt_engine - def release_estimator(self, context): - self.trt_context_pool.put(context) + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) diff --git a/cosyvoice/utils/file_utils.py b/cosyvoice/utils/file_utils.py index fb849e6..1fbddae 100644 --- a/cosyvoice/utils/file_utils.py +++ b/cosyvoice/utils/file_utils.py @@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() @@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device): model.llm.model.config.tie_word_embeddings = False model.llm.model.config.use_bias = True model.llm.model.save_pretrained(model_path) + os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path))) model.llm.model.config.vocab_size = tmp_vocab_size model.llm.model.config.tie_word_embeddings = tmp_tie_embedding model.llm.model.set_input_embeddings(embed_tokens)