lyuxiang.lx · 2024-12-16 10:37:10 +08:00
commit 3581caec76 · parent ac70560364
8 changed files with 24 additions and 22 deletions


@@ -123,8 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
-                else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -138,7 +138,7 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else \
+                Downsample1D(output_channel) if not is_last else
                 CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
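
The two hunks above are pure style fixes: the first moves the line break after the `else` keyword, and the second drops a backslash that is redundant inside a parenthesized expression, where implicit continuation already applies. A minimal standalone sketch of the before/after forms (toy strings stand in for the real modules):

causal = True

# Before: break before `else`, so the continuation line starts with a keyword.
resnet = "causal-resnet" if causal \
    else "resnet"

# After: break after `else`, so the continuation line starts with the value.
resnet = "causal-resnet" if causal else \
    "resnet"

# Inside brackets, implicit continuation makes the backslash unnecessary.
downsample = (
    "downsample" if not causal else
    "causal-conv"
)

Either spelling compiles to the same thing; the commit just standardizes which side of the keyword carries the break.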
@@ -147,7 +147,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[-1]
             out_channels = channels[-1]
             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
-                    ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
@@ -251,7 +251,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
             attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
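
The `attn_mask == 1` edit here (and in the two hunks below) is just PEP 8 spacing around the operator, but the call itself deserves a note: `mask_to_bias` is a wenet-style helper that turns a boolean keep/drop mask into an additive attention bias. A minimal sketch of the conventional implementation — an assumption about the helper's behavior, not necessarily CosyVoice's exact code:

import torch

def mask_to_bias_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Positions to keep (True) become 0.0; masked positions become a large
    # negative value that vanishes after softmax.
    assert mask.dtype == torch.bool
    return (1.0 - mask.to(dtype)) * -1.0e10

# Usage mirroring the diff: `attn_mask == 1` produces the bool tensor first.
attn_mask = torch.tensor([[1, 1, 0]])
bias = mask_to_bias_sketch(attn_mask == 1, torch.float32)  # [[0., 0., -1e10]]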
@@ -270,7 +270,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
             attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -287,7 +287,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
             attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -298,4 +298,4 @@ class ConditionalDecoder(nn.Module):
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
-        return output * mask
+        return output * mask
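
For context on the masking machinery touched above: with a positive `static_chunk_size`, a wenet-style `add_optional_chunk_mask` limits each frame's attention to its own fixed-size chunk plus all earlier chunks, which is what lets the decoder run in streaming mode. A toy sketch of that chunk-mask construction, simplified from the real helper (which also folds in the padding mask such as `mask_up.bool()`):

import torch

def static_chunk_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # Boolean (query, key) mask: frame i may attend to every frame in its
    # own chunk and in all earlier chunks (full left context).
    idx = torch.arange(seq_len)
    chunk_end = (idx // chunk_size + 1) * chunk_size  # end of frame i's chunk
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)

mask = static_chunk_mask(6, 2)
# Chunks are {0,1}, {2,3}, {4,5}: row 0 sees keys 0-1, row 2 sees keys 0-3,
# row 4 sees keys 0-5. Feeding `mask` through a bias helper like the sketch
# above gives the additive form the transformer blocks expect.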