commit ac70560364
parent 6b5931dc70
Author: lyuxiang.lx
Date: 2024-12-16 09:54:24 +08:00

8 changed files with 33 additions and 32 deletions


@@ -120,6 +120,7 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 class CosyVoice2(CosyVoice):
     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
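The constructor keeps the same optional-runtime flags. For reference, a typical instantiation following the repo's README convention (the checkpoint path here is an assumption; point it at your own download):

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # hypothetical local checkpoint directory
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                           load_jit=False, load_onnx=False, load_trt=False)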


@@ -49,7 +49,7 @@ class CausalBlock1D(Block1D):
 class CausalResnetBlock1D(ResnetBlock1D):
-    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
         super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
         self.block1 = CausalBlock1D(dim, dim_out)
         self.block2 = CausalBlock1D(dim_out, dim_out)
@@ -74,8 +74,7 @@ class CausalConv1d(torch.nn.Conv1d):
                          padding=0, dilation=dilation,
                          groups=groups, bias=bias,
                          padding_mode=padding_mode,
-                         device=device, dtype=dtype
-        )
+                         device=device, dtype=dtype)
         assert stride == 1
         self.causal_padding = (kernel_size - 1, 0)
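The surviving lines show the causal contract: stride must be 1, and padding is applied manually as `(kernel_size - 1, 0)`, i.e. zeros only on the left, so no output frame depends on future input. A minimal self-contained sketch of that idea (not the repo's exact class):

    import torch
    import torch.nn.functional as F

    class TinyCausalConv1d(torch.nn.Conv1d):
        """Conv1d whose receptive field covers only current and past frames."""
        def __init__(self, channels: int, kernel_size: int):
            super().__init__(channels, channels, kernel_size, stride=1, padding=0)
            self.causal_padding = (kernel_size - 1, 0)  # left-only zero padding

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, channels, time); output keeps the same time length
            return super().forward(F.pad(x, self.causal_padding))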
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
+                else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else \
+                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
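The chained conditional binds right-to-left, so the reflowed expression still means: use `Downsample1D` on every level except the last, and on the last level pick a causal conv (no future context) when `self.causal` is set, else a symmetrically padded conv. The same selection as plain branches (an illustration only, reusing the names from the hunk above):

    # equivalent to: A if not is_last else (B if self.causal else C)
    if not is_last:
        downsample = Downsample1D(output_channel)
    elif self.causal:
        downsample = CausalConv1d(output_channel, output_channel, 3)
    else:
        downsample = nn.Conv1d(output_channel, output_channel, 3, padding=1)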


@@ -202,7 +202,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
         # concat text and prompt_text
-        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
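The deleted line computed `token_len1`/`token_len2`, which nothing downstream used; the kept lines concatenate prompt and target tokens and zero out padded positions. A runnable sketch of that masking pattern, with a stand-in for the repo's `make_pad_mask` helper (assumed semantics: True at padded positions):

    import torch

    def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
        # (batch,) lengths -> (batch, max_len) bool, True where padding
        idx = torch.arange(int(lengths.max())).unsqueeze(0)
        return idx >= lengths.unsqueeze(1)

    prompt_token = torch.randint(0, 100, (2, 5))
    token = torch.randint(0, 100, (2, 7))
    prompt_token_len = torch.tensor([5, 3])
    token_len = torch.tensor([7, 6])

    # concat prompt and target, then mask out padded frames
    token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
    mask = (~make_pad_mask(token_len)).unsqueeze(-1).float()  # (batch, time, 1)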


@@ -19,7 +19,6 @@ from typing import Tuple
 import torch
 from torch import nn
-import torch.utils.checkpoint as ckpt
 from torch.nn import functional as F
 from cosyvoice.transformer.convolution import ConvolutionModule
@@ -49,14 +48,14 @@ class Upsample1D(nn.Module):
             number of output channels. Defaults to `channels`.
     """
-    def __init__(self, channels: int, out_channels: int, stride: int=2):
+    def __init__(self, channels: int, out_channels: int, stride: int = 2):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels
         self.stride = stride
         # In this mode, first repeat interpolate, than conv with stride=1
         self.conv = nn.Conv1d(
-            self.channels, self.out_channels, stride*2+1, stride=1,
+            self.channels, self.out_channels, stride * 2 + 1, stride=1,
             padding=0,
         )
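Per the in-code comment, this block upsamples by nearest-neighbour repetition first and smooths with a stride-1 conv of kernel `stride * 2 + 1` afterwards. A hedged sketch of such a forward pass; the symmetric padding is an assumption made to preserve the upsampled length, and the repo's actual padding scheme may differ:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyUpsample1D(nn.Module):
        """Repeat-interpolate by `stride`, then smooth with a stride-1 conv."""
        def __init__(self, channels: int, out_channels: int, stride: int = 2):
            super().__init__()
            self.stride = stride
            self.conv = nn.Conv1d(channels, out_channels, stride * 2 + 1, stride=1, padding=0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # (batch, channels, time) -> (batch, out_channels, time * stride)
            x = F.interpolate(x, scale_factor=float(self.stride), mode='nearest')
            x = F.pad(x, (self.stride, self.stride))  # assumed symmetric padding
            return self.conv(x)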
@@ -74,7 +73,7 @@ class PreLookaheadLayer(nn.Module):
         self.pre_lookahead_len = pre_lookahead_len
         self.conv1 = nn.Conv1d(
             channels, channels,
-            kernel_size=pre_lookahead_len+1,
+            kernel_size=pre_lookahead_len + 1,
             stride=1, padding=0,
         )
         self.conv2 = nn.Conv1d(
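With kernel `pre_lookahead_len + 1` and no built-in padding, a layer like this typically right-pads the input by `pre_lookahead_len` frames so each output frame can see that many future frames. The hunk cuts off before the real forward, so this is only a sketch under that assumption:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyPreLookahead(nn.Module):
        """Each output frame sees up to `pre_lookahead_len` future frames."""
        def __init__(self, channels: int, pre_lookahead_len: int = 3):
            super().__init__()
            self.pre_lookahead_len = pre_lookahead_len
            self.conv = nn.Conv1d(channels, channels,
                                  kernel_size=pre_lookahead_len + 1,
                                  stride=1, padding=0)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, channels, time); pad only the "future" (right) side
            x = F.pad(x, (0, self.pre_lookahead_len))
            return self.conv(x)  # output time length == input time length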