This commit is contained in:
lyuxiang.lx
2024-12-12 15:43:17 +08:00
parent 0bf706c26f
commit 2345ce6be2
4 changed files with 19 additions and 9 deletions

View File

@@ -157,6 +157,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
@@ -181,6 +183,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
@torch.inference_mode()
def inference(self,
@@ -206,7 +210,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
# text encode
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)

View File

@@ -240,6 +240,8 @@ def get_tokenizer(
class QwenTokenizer():
def __init__(self, token_path, skip_special_tokens=True):
super().__init__()
# NOTE: non-chat model, so all of these special tokens remain randomly initialized.
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
@@ -248,6 +250,9 @@ class QwenTokenizer():
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]"
]
}
self.tokenizer = AutoTokenizer.from_pretrained(token_path)