mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Merge pull request #1814 from hexisyztem/main
[BUG FIX] 使用 float64 避免精度误差问题,弃用 CPU 计算,避免拖累性能
This commit is contained in:
@@ -713,8 +713,8 @@ class CausalHiFTGenerator(HiFTGenerator):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||||
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
||||||
self.f0_predictor.to('cpu')
|
self.f0_predictor.to(torch.float64)
|
||||||
f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
|
f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
|
||||||
# f0->source
|
# f0->source
|
||||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||||
s, _, _ = self.m_source(s)
|
s, _, _ = self.m_source(s)
|
||||||
|
|||||||
Reference in New Issue
Block a user