mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
Merge pull request #1758 from orbisai0security/fix/V-005-pickle-deserialization
[Security] Fix CRITICAL vulnerability: V-005
This commit is contained in:
@@ -47,7 +47,7 @@ class CosyVoiceFrontEnd:
|
|||||||
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
||||||
"CPUExecutionProvider"])
|
"CPUExecutionProvider"])
|
||||||
if os.path.exists(spk2info):
|
if os.path.exists(spk2info):
|
||||||
self.spk2info = torch.load(spk2info, map_location=self.device)
|
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
|
||||||
else:
|
else:
|
||||||
self.spk2info = {}
|
self.spk2info = {}
|
||||||
self.allowed_special = allowed_special
|
self.allowed_special = allowed_special
|
||||||
|
|||||||
@@ -63,12 +63,12 @@ class CosyVoiceModel:
|
|||||||
self.silent_tokens = []
|
self.silent_tokens = []
|
||||||
|
|
||||||
def load(self, llm_model, flow_model, hift_model):
|
def load(self, llm_model, flow_model, hift_model):
|
||||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
|
||||||
self.llm.to(self.device).eval()
|
self.llm.to(self.device).eval()
|
||||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
|
||||||
self.flow.to(self.device).eval()
|
self.flow.to(self.device).eval()
|
||||||
# in case hift_model is a hifigan model
|
# in case hift_model is a hifigan model
|
||||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
|
||||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||||
self.hift.to(self.device).eval()
|
self.hift.to(self.device).eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user