mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update
This commit is contained in:
@@ -287,8 +287,6 @@ class CosyVoice2Model:
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
if self.fp16 is True:
|
||||
self.llm.half()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
self.flow.decoder.fp16 = False
|
||||
@@ -319,8 +317,6 @@ class CosyVoice2Model:
|
||||
self.flow.decoder.fp16 = True
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
if self.fp16 is True:
|
||||
llm_embedding = llm_embedding.half()
|
||||
with self.llm_context:
|
||||
for i in self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
|
||||
@@ -136,41 +136,26 @@ class ConditionalCFM(BASECFM):
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spk': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy(),
|
||||
'mask_rand': torch.randn(1, 1, 1).numpy()
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output = self.estimator.run(None, ort_inputs)[0]
|
||||
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
if not mask.is_contiguous():
|
||||
mask = mask.contiguous()
|
||||
if not mu.is_contiguous():
|
||||
mu = mu.contiguous()
|
||||
if not t.is_contiguous():
|
||||
t = t.contiguous()
|
||||
if not spks.is_contiguous():
|
||||
spks = spks.contiguous()
|
||||
if not cond.is_contiguous():
|
||||
cond = cond.contiguous()
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spk', (2, 80))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask_rand', (1, 1, 1))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.data_ptr(),
|
||||
mask.data_ptr(),
|
||||
mu.data_ptr(),
|
||||
t.data_ptr(),
|
||||
spks.data_ptr(),
|
||||
cond.data_ptr(),
|
||||
torch.randn(1, 1, 1).to(x.device).data_ptr(),
|
||||
x.data_ptr()])
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
@@ -241,7 +226,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
||||
"""
|
||||
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
||||
if self.sp16 is True:
|
||||
if self.fp16 is True:
|
||||
z = z.half()
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
|
||||
Reference in New Issue
Block a user