mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update
This commit is contained in:
@@ -124,7 +124,7 @@ from cosyvoice.utils.file_utils import load_wav
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
## cosyvoice2 usage
|
## cosyvoice2 usage
|
||||||
cosyvoice2 = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, load_trt=False)
|
cosyvoice2 = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, load_trt=False)
|
||||||
# sft usage
|
# sft usage
|
||||||
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||||
for i, j in enumerate(cosyvoice2.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True)):
|
for i, j in enumerate(cosyvoice2.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True)):
|
||||||
|
|||||||
@@ -287,8 +287,6 @@ class CosyVoice2Model:
|
|||||||
def load(self, llm_model, flow_model, hift_model):
|
def load(self, llm_model, flow_model, hift_model):
|
||||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
||||||
self.llm.to(self.device).eval()
|
self.llm.to(self.device).eval()
|
||||||
if self.fp16 is True:
|
|
||||||
self.llm.half()
|
|
||||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
||||||
self.flow.to(self.device).eval()
|
self.flow.to(self.device).eval()
|
||||||
self.flow.decoder.fp16 = False
|
self.flow.decoder.fp16 = False
|
||||||
@@ -319,8 +317,6 @@ class CosyVoice2Model:
|
|||||||
self.flow.decoder.fp16 = True
|
self.flow.decoder.fp16 = True
|
||||||
|
|
||||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||||
if self.fp16 is True:
|
|
||||||
llm_embedding = llm_embedding.half()
|
|
||||||
with self.llm_context:
|
with self.llm_context:
|
||||||
for i in self.llm.inference(text=text.to(self.device),
|
for i in self.llm.inference(text=text.to(self.device),
|
||||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
|
|||||||
@@ -136,41 +136,26 @@ class ConditionalCFM(BASECFM):
|
|||||||
'mask': mask.cpu().numpy(),
|
'mask': mask.cpu().numpy(),
|
||||||
'mu': mu.cpu().numpy(),
|
'mu': mu.cpu().numpy(),
|
||||||
't': t.cpu().numpy(),
|
't': t.cpu().numpy(),
|
||||||
'spk': spks.cpu().numpy(),
|
'spks': spks.cpu().numpy(),
|
||||||
'cond': cond.cpu().numpy(),
|
'cond': cond.cpu().numpy()
|
||||||
'mask_rand': torch.randn(1, 1, 1).numpy()
|
|
||||||
}
|
}
|
||||||
output = self.estimator.run(None, ort_inputs)[0]
|
output = self.estimator.run(None, ort_inputs)[0]
|
||||||
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
||||||
else:
|
else:
|
||||||
if not x.is_contiguous():
|
|
||||||
x = x.contiguous()
|
|
||||||
if not mask.is_contiguous():
|
|
||||||
mask = mask.contiguous()
|
|
||||||
if not mu.is_contiguous():
|
|
||||||
mu = mu.contiguous()
|
|
||||||
if not t.is_contiguous():
|
|
||||||
t = t.contiguous()
|
|
||||||
if not spks.is_contiguous():
|
|
||||||
spks = spks.contiguous()
|
|
||||||
if not cond.is_contiguous():
|
|
||||||
cond = cond.contiguous()
|
|
||||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('t', (2,))
|
self.estimator.set_input_shape('t', (2,))
|
||||||
self.estimator.set_input_shape('spk', (2, 80))
|
self.estimator.set_input_shape('spks', (2, 80))
|
||||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('mask_rand', (1, 1, 1))
|
|
||||||
# run trt engine
|
# run trt engine
|
||||||
self.estimator.execute_v2([x.data_ptr(),
|
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||||
mask.data_ptr(),
|
mask.contiguous().data_ptr(),
|
||||||
mu.data_ptr(),
|
mu.contiguous().data_ptr(),
|
||||||
t.data_ptr(),
|
t.contiguous().data_ptr(),
|
||||||
spks.data_ptr(),
|
spks.contiguous().data_ptr(),
|
||||||
cond.data_ptr(),
|
cond.contiguous().data_ptr(),
|
||||||
torch.randn(1, 1, 1).to(x.device).data_ptr(),
|
x.data_ptr()])
|
||||||
x.data_ptr()])
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||||
@@ -241,7 +226,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
||||||
if self.sp16 is True:
|
if self.fp16 is True:
|
||||||
z = z.half()
|
z = z.half()
|
||||||
# fix prompt and overlap part mu and z
|
# fix prompt and overlap part mu and z
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||||
|
conformer==0.3.2
|
||||||
deepspeed==0.14.2; sys_platform == 'linux'
|
deepspeed==0.14.2; sys_platform == 'linux'
|
||||||
diffusers==0.27.2
|
diffusers==0.27.2
|
||||||
gdown==5.1.0
|
gdown==5.1.0
|
||||||
@@ -16,8 +17,8 @@ modelscope==1.15.0
|
|||||||
networkx==3.1
|
networkx==3.1
|
||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
onnx==1.16.0
|
onnx==1.16.0
|
||||||
onnxruntime-gpu==1.16.0; sys_platform == 'linux'
|
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
|
||||||
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
|
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
|
||||||
openai-whisper==20231117
|
openai-whisper==20231117
|
||||||
protobuf==4.25
|
protobuf==4.25
|
||||||
pydantic==2.7.0
|
pydantic==2.7.0
|
||||||
@@ -25,8 +26,11 @@ rich==13.7.1
|
|||||||
soundfile==0.12.1
|
soundfile==0.12.1
|
||||||
tensorboard==2.14.0
|
tensorboard==2.14.0
|
||||||
tensorrt-cu12==10.0.1
|
tensorrt-cu12==10.0.1
|
||||||
|
tensorrt-cu12-bindings==10.0.1
|
||||||
|
tensorrt-cu12-libs==10.0.1
|
||||||
torch==2.3.1
|
torch==2.3.1
|
||||||
torchaudio==2.3.1
|
torchaudio==2.3.1
|
||||||
|
transformers==4.40.1
|
||||||
uvicorn==0.30.0
|
uvicorn==0.30.0
|
||||||
wget==3.2
|
wget==3.2
|
||||||
fastapi==0.111.0
|
fastapi==0.111.0
|
||||||
|
|||||||
Reference in New Issue
Block a user