This commit is contained in:
lyuxiang.lx
2024-12-16 09:54:24 +08:00
parent 6b5931dc70
commit ac70560364
8 changed files with 33 additions and 32 deletions

View File

@@ -120,6 +120,7 @@ class CosyVoice:
yield model_output yield model_output
start_time = time.time() start_time = time.time()
class CosyVoice2(CosyVoice): class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False): def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):

View File

@@ -74,8 +74,7 @@ class CausalConv1d(torch.nn.Conv1d):
padding=0, dilation=dilation, padding=0, dilation=dilation,
groups=groups, bias=bias, groups=groups, bias=bias,
padding_mode=padding_mode, padding_mode=padding_mode,
device=device, dtype=dtype device=device, dtype=dtype)
)
assert stride == 1 assert stride == 1
self.causal_padding = (kernel_size - 1, 0) self.causal_padding = (kernel_size - 1, 0)
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
input_channel = output_channel input_channel = output_channel
output_channel = channels[i] output_channel = channels[i]
is_last = i == len(channels) - 1 is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList( transformer_blocks = nn.ModuleList(
[ [
BasicTransformerBlock( BasicTransformerBlock(
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
] ]
) )
downsample = ( downsample = (
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) Downsample1D(output_channel) if not is_last else \
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
) )
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks): for _ in range(num_mid_blocks):
input_channel = channels[-1] input_channel = channels[-1]
out_channels = channels[-1] out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList( transformer_blocks = nn.ModuleList(
[ [

View File

@@ -202,7 +202,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
embedding = self.spk_embed_affine_layer(embedding) embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text # concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask token = self.input_embedding(torch.clamp(token, min=0)) * mask

View File

@@ -19,7 +19,6 @@ from typing import Tuple
import torch import torch
from torch import nn from torch import nn
import torch.utils.checkpoint as ckpt
from torch.nn import functional as F from torch.nn import functional as F
from cosyvoice.transformer.convolution import ConvolutionModule from cosyvoice.transformer.convolution import ConvolutionModule