mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update lint
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
@@ -6,6 +5,7 @@ from typing import List, Optional, Tuple
|
||||
from einops import rearrange
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
|
||||
class MultipleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self, mpd: nn.Module, mrd: nn.Module
|
||||
@@ -28,6 +28,7 @@ class MultipleDiscriminator(nn.Module):
|
||||
fmap_gs += this_fmap_gs
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -112,7 +113,7 @@ class DiscriminatorR(nn.Module):
|
||||
x = torch.view_as_real(x)
|
||||
x = rearrange(x, "b f t c -> b c t f")
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||
@@ -136,4 +137,4 @@ class DiscriminatorR(nn.Module):
|
||||
fmap.append(x)
|
||||
x += h
|
||||
|
||||
return x, fmap
|
||||
return x, fmap
|
||||
|
||||
Reference in New Issue
Block a user