mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update readme
This commit is contained in:
@@ -47,11 +47,11 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(level
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_VOCAB_SIZE = 151663
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
class CosyVoice2:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
|
||||
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
@@ -61,7 +61,7 @@ class CosyVoice2:
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
|
||||
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
|
||||
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
@@ -77,8 +77,9 @@ class CosyVoice2Model:
|
||||
def __init__(self,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
fp16: bool = False,
|
||||
device: str = 'cuda'):
|
||||
self.device = device
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
@@ -179,11 +180,11 @@ class TritonPythonModel:
|
||||
model_dir = model_params["model_dir"]
|
||||
|
||||
# Initialize device and vocoder
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
||||
|
||||
self.token2wav_model = CosyVoice2(
|
||||
model_dir, load_jit=True, load_trt=True, fp16=True
|
||||
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
|
||||
)
|
||||
|
||||
logger.info("Token2Wav initialized successfully")
|
||||
@@ -224,7 +225,6 @@ class TritonPythonModel:
|
||||
else:
|
||||
stream = False
|
||||
request_id = request.request_id()
|
||||
print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
|
||||
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
|
||||
prompt_token=prompt_speech_tokens,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
@@ -234,7 +234,6 @@ class TritonPythonModel:
|
||||
stream=stream,
|
||||
finalize=finalize)
|
||||
if finalize:
|
||||
print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
|
||||
self.token2wav_model.model.hift_cache_dict.pop(request_id)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user