mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix lint
This commit is contained in:
@@ -49,7 +49,7 @@ class CausalBlock1D(Block1D):
|
||||
|
||||
|
||||
class CausalResnetBlock1D(ResnetBlock1D):
|
||||
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
|
||||
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
||||
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
||||
self.block1 = CausalBlock1D(dim, dim_out)
|
||||
self.block2 = CausalBlock1D(dim_out, dim_out)
|
||||
@@ -70,12 +70,11 @@ class CausalConv1d(torch.nn.Conv1d):
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
kernel_size, stride,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride == 1
|
||||
self.causal_padding = (kernel_size - 1, 0)
|
||||
|
||||
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = channels[i]
|
||||
is_last = i == len(channels) - 1
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
|
||||
else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
|
||||
]
|
||||
)
|
||||
downsample = (
|
||||
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
Downsample1D(output_channel) if not is_last else \
|
||||
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||
|
||||
for _ in range(num_mid_blocks):
|
||||
input_channel = channels[-1]
|
||||
out_channels = channels[-1]
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
||||
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user