Author: lyuxiang.lx
Date: 2024-12-16 10:37:10 +08:00
parent ac70560364
commit 3581caec76
8 changed files with 24 additions and 22 deletions


@@ -149,6 +149,11 @@ class CosyVoiceFrontEnd:
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+if resample_rate == 24000:
+    # cosyvoice2, force speech_feat % speech_token = 2
+    token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+    speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+    speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_speech_16k)
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
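
The block added above pins the 24 kHz (CosyVoice2) prompt to exactly two mel frames per speech token by trimming whichever side is longer. A minimal standalone sketch of that trimming, with made-up tensor shapes (not taken from the repo):

import torch

speech_feat = torch.randn(1, 101, 80)           # (batch, mel frames, mel bins), illustrative
speech_token = torch.randint(0, 4096, (1, 52))  # (batch, speech tokens), illustrative

token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])  # min(50, 52) = 50
speech_feat = speech_feat[:, :2 * token_len]    # keep 100 mel frames
speech_token = speech_token[:, :token_len]      # keep 50 tokens
assert speech_feat.shape[1] == 2 * speech_token.shape[1]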


@@ -379,8 +379,7 @@ class CosyVoice2Model:
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
-this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
-    .unsqueeze(dim=0)
+this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
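
This loop blocks until at least token_hop_len + pre_lookahead_len unsynthesized tokens have accumulated past token_offset before slicing a chunk for token2wav. A toy sketch of that readiness check; the values and the offset update are illustrative assumptions, not the repo's exact bookkeeping:

token_hop_len, pre_lookahead_len, token_offset = 25, 3, 0
generated = list(range(30))   # speech tokens produced so far by the LM job

if len(generated) - token_offset >= token_hop_len + pre_lookahead_len:
    # chunk handed to token2wav: everything up to the hop boundary plus the lookahead tokens
    chunk = generated[:token_offset + token_hop_len + pre_lookahead_len]   # 28 tokens here
    token_offset += token_hop_len   # assumed advance; the lookahead tokens are re-sent with the next chunk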


@@ -123,8 +123,8 @@ class ConditionalDecoder(nn.Module):
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
-resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
-    else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+    ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
@@ -138,7 +138,7 @@ class ConditionalDecoder(nn.Module):
]
)
downsample = (
-Downsample1D(output_channel) if not is_last else \
+Downsample1D(output_channel) if not is_last else
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
@@ -147,7 +147,7 @@ class ConditionalDecoder(nn.Module):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
-ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
@@ -251,7 +251,7 @@ class ConditionalDecoder(nn.Module):
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
-attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -270,7 +270,7 @@ class ConditionalDecoder(nn.Module):
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
-attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -287,7 +287,7 @@ class ConditionalDecoder(nn.Module):
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
-attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@@ -298,4 +298,4 @@ class ConditionalDecoder(nn.Module):
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
-return output * mask
+return output * mask
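
Each stage above builds a chunk-aware attention mask and then converts the boolean mask into an additive bias for the transformer blocks. A hedged re-implementation of that boolean-to-bias step, not necessarily the repo's exact mask_to_bias:

import torch

def bool_mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True  -> 0.0   (position may be attended)
    # False -> -1e10 (effectively removed by the softmax)
    return (1.0 - mask.to(dtype)) * -1.0e10

attn_mask = torch.tensor([[True, True, False]])
print(bool_mask_to_bias(attn_mask, torch.float32))   # tensor([[0., 0., -1e10]])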


@@ -150,12 +150,12 @@ class ConditionalCFM(BASECFM):
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
-mask.contiguous().data_ptr(),
-mu.contiguous().data_ptr(),
-t.contiguous().data_ptr(),
-spks.contiguous().data_ptr(),
-cond.contiguous().data_ptr(),
-x.data_ptr()])
+mask.contiguous().data_ptr(),
+mu.contiguous().data_ptr(),
+t.contiguous().data_ptr(),
+spks.contiguous().data_ptr(),
+cond.contiguous().data_ptr(),
+x.data_ptr()])
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
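
execute_v2 takes one flat list of device pointers in the engine's binding order; in the call above the six inputs come first and the final pointer is the output buffer, which reuses x's storage. A sketch of assembling such a list; the tensor shapes are guesses based on the single set_input_shape call shown, and the last line is left commented because it needs a built TensorRT engine and CUDA buffers:

import torch

T = 200                                  # arbitrary sequence length for the sketch
x    = torch.randn(2, 80, T)
mask = torch.ones(2, 1, T)
mu   = torch.randn(2, 80, T)
t    = torch.zeros(2)
spks = torch.randn(2, 80)
cond = torch.randn(2, 80, T)             # matches set_input_shape('cond', (2, 80, x.size(2)))

bindings = [v.contiguous().data_ptr() for v in (x, mask, mu, t, spks, cond)]
bindings.append(x.data_ptr())            # last binding: output buffer, overwriting x
# self.estimator.execute_v2(bindings)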


@@ -337,4 +337,4 @@ class Qwen2LM(torch.nn.Module):
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
-lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
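
After each token is yielded, its id is turned back into an input embedding for the next decoding step. A minimal sketch of that lookup; the embedding sizes here are assumptions, not values from the repo:

import torch

speech_embedding = torch.nn.Embedding(num_embeddings=6564, embedding_dim=896)   # sizes are guesses
top_ids = 123                                                                    # id sampled this step
lm_input = speech_embedding.weight[top_ids].reshape(1, 1, -1)                    # (batch=1, seq=1, hidden)
assert lm_input.shape == (1, 1, 896)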


@@ -269,6 +269,7 @@ class QwenTokenizer():
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
return text
+@lru_cache(maxsize=None)
def get_qwen_tokenizer(
token_path: str,
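
The @lru_cache added above memoizes get_qwen_tokenizer, so repeated calls with the same arguments reuse one tokenizer instance instead of reloading it. A small illustration of the caching behaviour with a stand-in function (not the repo's code; the path is hypothetical):

from functools import lru_cache

@lru_cache(maxsize=None)
def load_tokenizer(token_path: str, skip_special_tokens: bool = True):
    print(f'loading tokenizer from {token_path}')   # printed only on the first call per argument set
    return object()                                  # stand-in for the real QwenTokenizer

a = load_tokenizer('CosyVoice-BlankEN')
b = load_tokenizer('CosyVoice-BlankEN')
assert a is b                                        # second call is a cache hit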


@@ -54,10 +54,7 @@ class Upsample1D(nn.Module):
self.out_channels = out_channels
self.stride = stride
# In this mode, first repeat interpolate, than conv with stride=1
-self.conv = nn.Conv1d(
-    self.channels, self.out_channels, stride * 2 + 1, stride = 1,
-    padding=0,
-)
+self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
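
Upsample1D first repeats each frame `stride` times with nearest-neighbour interpolation and only then applies the stride-1 conv defined above. A sketch of the length bookkeeping; the left padding is an assumption made so the example keeps the upsampled length, and the channel count is illustrative:

import torch
import torch.nn.functional as F

stride, channels = 2, 512
conv = torch.nn.Conv1d(channels, channels, stride * 2 + 1, stride=1, padding=0)

x = torch.randn(1, channels, 50)                                   # (B, C, T)
up = F.interpolate(x, scale_factor=float(stride), mode="nearest")  # (B, C, 100): each frame repeated
up = F.pad(up, (stride * 2, 0))                                    # assumed causal left pad of 2 * stride
y = conv(up)                                                       # (B, C, 100)
assert y.shape[-1] == x.shape[-1] * stride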