mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -120,6 +120,7 @@ class CosyVoice:
|
|||||||
yield model_output
|
yield model_output
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
class CosyVoice2(CosyVoice):
|
class CosyVoice2(CosyVoice):
|
||||||
|
|
||||||
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
|
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
|
||||||
|
|||||||
@@ -74,8 +74,7 @@ class CausalConv1d(torch.nn.Conv1d):
|
|||||||
padding=0, dilation=dilation,
|
padding=0, dilation=dilation,
|
||||||
groups=groups, bias=bias,
|
groups=groups, bias=bias,
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
device=device, dtype=dtype
|
device=device, dtype=dtype)
|
||||||
)
|
|
||||||
assert stride == 1
|
assert stride == 1
|
||||||
self.causal_padding = (kernel_size - 1, 0)
|
self.causal_padding = (kernel_size - 1, 0)
|
||||||
|
|
||||||
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
|
|||||||
input_channel = output_channel
|
input_channel = output_channel
|
||||||
output_channel = channels[i]
|
output_channel = channels[i]
|
||||||
is_last = i == len(channels) - 1
|
is_last = i == len(channels) - 1
|
||||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
|
||||||
|
else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
transformer_blocks = nn.ModuleList(
|
transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
BasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
downsample = (
|
downsample = (
|
||||||
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
Downsample1D(output_channel) if not is_last else \
|
||||||
|
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||||
)
|
)
|
||||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||||
|
|
||||||
for _ in range(num_mid_blocks):
|
for _ in range(num_mid_blocks):
|
||||||
input_channel = channels[-1]
|
input_channel = channels[-1]
|
||||||
out_channels = channels[-1]
|
out_channels = channels[-1]
|
||||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
||||||
|
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
|
|
||||||
transformer_blocks = nn.ModuleList(
|
transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -202,7 +202,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
embedding = self.spk_embed_affine_layer(embedding)
|
embedding = self.spk_embed_affine_layer(embedding)
|
||||||
|
|
||||||
# concat text and prompt_text
|
# concat text and prompt_text
|
||||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
|
||||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from typing import Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.utils.checkpoint as ckpt
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from cosyvoice.transformer.convolution import ConvolutionModule
|
from cosyvoice.transformer.convolution import ConvolutionModule
|
||||||
|
|||||||
Reference in New Issue
Block a user