feat: v1.5 Gradio for Windows & Linux

zzzweakman
2025-04-11 02:43:04 +08:00
parent 2e5b74a257
commit b9b459a119
8 changed files with 330 additions and 169 deletions


@@ -49,8 +49,9 @@ class AudioProcessor:
     whisper_feature = []
     # Process multiple 30s mel input features
     for input_feature in whisper_input_features:
-        audio_feats = whisper.encoder(input_feature.to(device), output_hidden_states=True).hidden_states
-        audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype)
+        input_feature = input_feature.to(device).to(weight_dtype)
+        audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
+        audio_feats = torch.stack(audio_feats, dim=2)
         whisper_feature.append(audio_feats)
     whisper_feature = torch.cat(whisper_feature, dim=1)
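
This hunk moves the device/dtype cast onto the mel features themselves, so the Whisper encoder sees inputs already in weight_dtype and the stacked hidden states no longer need a second .to(weight_dtype) cast. Below is a minimal, self-contained sketch of the resulting loop, assuming a Hugging Face transformers WhisperModel and random placeholder mel features; the checkpoint name, dtypes, and tensor shapes are illustrative, not the repository's exact setup.

import torch
from transformers import WhisperModel

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32

# Illustrative checkpoint; the repo's actual Whisper weights may differ.
whisper = WhisperModel.from_pretrained("openai/whisper-tiny").to(device, dtype=weight_dtype).eval()

# One (batch, n_mels, frames) tensor per 30 s chunk, as a WhisperFeatureExtractor would produce.
whisper_input_features = [torch.randn(1, 80, 3000)]

whisper_feature = []
with torch.no_grad():
    for input_feature in whisper_input_features:
        # Cast the mel features to the encoder's dtype up front, so every
        # returned hidden state is already in weight_dtype.
        input_feature = input_feature.to(device).to(weight_dtype)
        audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
        # Stack all encoder layers: (batch, seq_len, n_layers + 1, hidden_dim).
        audio_feats = torch.stack(audio_feats, dim=2)
        whisper_feature.append(audio_feats)
# Concatenate the per-chunk features along the time axis.
whisper_feature = torch.cat(whisper_feature, dim=1)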


@@ -13,9 +13,9 @@ from musetalk.models.unet import UNet,PositionalEncoding
 def load_all_model(
-    unet_model_path=os.path.join("models", "musetalk", "pytorch_model.bin"),
+    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
     vae_type="sd-vae",
-    unet_config=os.path.join("models", "musetalk", "musetalk.json"),
+    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
     device=None,
 ):
     vae = VAE(
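
With this change, load_all_model defaults to the v1.5 checkpoints under models/musetalkV15/ (unet.pth plus its musetalk.json config) instead of the v1.0 models/musetalk/pytorch_model.bin. A short usage sketch follows; the import path and the returned objects are assumptions based on this file's contents, not verified against the rest of the commit.

import os
# Assumed module path for the function shown in the diff.
from musetalk.utils.utils import load_all_model

# With no arguments, the v1.5 weights are used:
#   models/musetalkV15/unet.pth + models/musetalkV15/musetalk.json
vae, unet, pe = load_all_model()

# The previous v1.0 checkpoint can still be selected by passing the old paths explicitly.
vae, unet, pe = load_all_model(
    unet_model_path=os.path.join("models", "musetalk", "pytorch_model.bin"),
    unet_config=os.path.join("models", "musetalk", "musetalk.json"),
)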