update readme

This commit is contained in:
yuekaiz
2025-09-03 17:42:14 +08:00
parent e04699c6da
commit 633b991290
5 changed files with 143 additions and 77 deletions

View File

@@ -47,11 +47,11 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
@@ -61,7 +61,7 @@ class CosyVoice2:
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
@@ -77,8 +77,9 @@ class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
@@ -179,11 +180,11 @@ class TritonPythonModel:
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=True, load_trt=True, fp16=True
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
logger.info("Token2Wav initialized successfully")
@@ -224,7 +225,6 @@ class TritonPythonModel:
else:
stream = False
request_id = request.request_id()
print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
@@ -234,7 +234,6 @@ class TritonPythonModel:
stream=stream,
finalize=finalize)
if finalize:
print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
self.token2wav_model.model.hift_cache_dict.pop(request_id)
else: