Merge pull request #1758 from orbisai0security/fix/V-005-pickle-deserialization

[Security] Fix CRITICAL vulnerability: V-005
This commit is contained in:
Xiang Lyu
2025-12-31 11:45:44 +08:00
committed by GitHub
2 changed files with 4 additions and 4 deletions

View File

@@ -47,7 +47,7 @@ class CosyVoiceFrontEnd:
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
if os.path.exists(spk2info):
- self.spk2info = torch.load(spk2info, map_location=self.device)
+ self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
self.spk2info = {}
self.allowed_special = allowed_special

View File

@@ -63,12 +63,12 @@ class CosyVoiceModel:
self.silent_tokens = []
def load(self, llm_model, flow_model, hift_model):
- self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
self.llm.to(self.device).eval()
- self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
- hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()