This commit is contained in:
lyuxiang.lx
2024-12-16 09:54:24 +08:00
parent 6b5931dc70
commit ac70560364
8 changed files with 33 additions and 32 deletions

View File

@@ -120,6 +120,7 @@ class CosyVoice:
yield model_output
start_time = time.time()
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
@@ -153,4 +154,4 @@ class CosyVoice2(CosyVoice):
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
del configs
del configs

View File

@@ -123,7 +123,7 @@ class CosyVoiceFrontEnd:
text = remove_bracket(text)
text = re.sub(r'[,、]+$', '', text)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
token_min_n=60, merge_len=20, comma_split=False))
else:
if self.use_ttsfrd:
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
@@ -132,7 +132,7 @@ class CosyVoiceFrontEnd:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
token_min_n=60, merge_len=20, comma_split=False))
if split is False:
return text
return texts

View File

@@ -330,13 +330,13 @@ class CosyVoice2Model:
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
finalize=finalize)
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
@@ -418,4 +418,4 @@ class CosyVoice2Model:
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)

View File

@@ -49,7 +49,7 @@ class CausalBlock1D(Block1D):
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
@@ -70,12 +70,11 @@ class CausalConv1d(torch.nn.Conv1d):
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype
)
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
]
)
downsample = (
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
Downsample1D(output_channel) if not is_last else \
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[

View File

@@ -202,7 +202,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -211,7 +210,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# get conditions
@@ -229,4 +228,4 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat, None
return feat, None

View File

@@ -274,4 +274,4 @@ def get_qwen_tokenizer(
token_path: str,
skip_special_tokens: bool
) -> QwenTokenizer:
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)

View File

@@ -19,7 +19,6 @@ from typing import Tuple
import torch
from torch import nn
import torch.utils.checkpoint as ckpt
from torch.nn import functional as F
from cosyvoice.transformer.convolution import ConvolutionModule
@@ -49,14 +48,14 @@ class Upsample1D(nn.Module):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels: int, out_channels: int, stride: int=2):
def __init__(self, channels: int, out_channels: int, stride: int = 2):
super().__init__()
self.channels = channels
self.out_channels = out_channels
self.stride = stride
# In this mode, first repeat interpolate, than conv with stride=1
self.conv = nn.Conv1d(
self.channels, self.out_channels, stride*2+1, stride=1,
self.channels, self.out_channels, stride * 2 + 1, stride = 1,
padding=0,
)
@@ -74,7 +73,7 @@ class PreLookaheadLayer(nn.Module):
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
kernel_size=pre_lookahead_len+1,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
@@ -315,8 +314,8 @@ class UpsampleConformerEncoder(torch.nn.Module):
return xs
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
pos_emb: torch.Tensor,
mask_pad: torch.Tensor) -> torch.Tensor:
for layer in self.up_encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs

View File

@@ -163,4 +163,4 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
# NOTE(Mddct): torch.finfo jit issues
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
mask = (1.0 - mask) * torch.finfo(dtype).min
return mask
return mask