mirror of https://github.com/FunAudioLLM/CosyVoice.git
add flow trt wrapper
cosyvoice/cli/cosyvoice.py
@@ -137,7 +137,7 @@ class CosyVoice:

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -159,7 +159,7 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent)
        self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
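With this change, callers can opt into multiple TensorRT execution contexts when constructing CosyVoice2. A minimal usage sketch; the model directory path is illustrative:

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # load_trt=True loads the TensorRT flow decoder; trt_concurrent=2
    # provisions two execution contexts so two sessions can decode in parallel
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                           load_trt=True, fp16=True, trt_concurrent=2)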
cosyvoice/cli/model.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +14,7 @@
 # limitations under the License.
 import os
 from typing import Generator
+import queue
 import torch
 import numpy as np
 import threading
@@ -22,6 +24,7 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.common import TrtContextWrapper


 class CosyVoiceModel:
@@ -89,9 +92,12 @@ class CosyVoiceModel:
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        if isinstance(self, CosyVoice2Model):
+            self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
+        else:
+            self.flow.decoder.estimator = estimator_engine.create_execution_context()

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
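The diff uses TrtContextWrapper from cosyvoice.utils.common without showing its body. A minimal sketch of what such a wrapper could look like, assuming it shares one deserialized engine across a pool of execution contexts, one per concurrent flow-decoder call; the actual implementation may differ:

    import queue

    class TrtContextWrapper:
        """Sketch: pool N TensorRT execution contexts over one engine."""

        def __init__(self, trt_engine, trt_concurrent=1):
            self.trt_engine = trt_engine
            # one IExecutionContext per concurrent consumer
            self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
            for _ in range(trt_concurrent):
                self.trt_context_pool.put(trt_engine.create_execution_context())

        def acquire_estimator(self):
            # blocks until a context is free
            return self.trt_context_pool.get(), self.trt_engine

        def release_estimator(self, context):
            self.trt_context_pool.put(context)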
@@ -231,7 +237,9 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()


 class CosyVoice2Model(CosyVoiceModel):
@@ -241,13 +249,15 @@ class CosyVoice2Model(CosyVoiceModel):
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool = False,
-                 use_flow_cache: bool = False):
+                 use_flow_cache: bool = False,
+                 trt_concurrent: int = 1):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
         self.use_flow_cache = use_flow_cache
+        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -261,12 +271,16 @@ class CosyVoice2Model(CosyVoiceModel):
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        for _ in range(trt_concurrent):
+            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.trt_context_dict = {}

     def init_flow_cache(self):
         encoder_cache = {'offset': 0,
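Each pooled entry is the value returned by torch.cuda.stream(...), a context manager that redirects kernel launches to a private CUDA stream (or a nullcontext on CPU), which is why a pooled entry can later be entered with a `with` statement. A standalone illustration of that construction:

    import torch
    from contextlib import nullcontext

    # same construction as the pool entries above
    ctx = torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext()
    with ctx:
        # on GPU, kernels launched here go to the private stream
        # instead of the default stream
        pass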
@@ -304,7 +318,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16):
+        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                       prompt_token=prompt_token.to(self.device),
@@ -349,6 +363,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
         self.hift_cache_dict[this_uuid] = None
         self.flow_cache_dict[this_uuid] = self.init_flow_cache()
+        self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -405,4 +420,8 @@ class CosyVoice2Model(CosyVoiceModel):
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        self.trt_context_pool.put(self.trt_context_dict[this_uuid])
+        self.trt_context_dict.pop(this_uuid)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()
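Taken together, each tts session checks a stream context out of trt_context_pool when it starts and returns it in this cleanup block, so at most trt_concurrent sessions decode simultaneously and any extra session blocks on the get(). A self-contained illustration of that bounded-pool pattern (not CosyVoice code):

    import queue
    import threading

    pool = queue.Queue(maxsize=2)
    for i in range(2):
        pool.put('context-{}'.format(i))

    def session(uid):
        ctx = pool.get()        # blocks while both contexts are checked out
        try:
            print('session', uid, 'running with', ctx)
        finally:
            pool.put(ctx)       # always return the context to the pool

    threads = [threading.Thread(target=session, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()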