mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
fix lint
This commit is contained in:
@@ -102,6 +102,7 @@ def init_weights(m, mean=0.0, std=0.01):
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
# Repetition Aware Sampling in VALL-E 2
|
||||
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
@@ -110,6 +111,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
|
||||
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||
return top_ids
|
||||
|
||||
|
||||
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
prob, indices = [], []
|
||||
cum_prob = 0.0
|
||||
@@ -127,13 +129,16 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
top_ids = indices[prob.multinomial(1, replacement=True)]
|
||||
return top_ids
|
||||
|
||||
|
||||
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
||||
return top_ids
|
||||
|
||||
|
||||
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||
device = fade_in_mel.device
|
||||
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||
mel_overlap_len = int(window.shape[0] / 2)
|
||||
fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||
fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel.to(device)
|
||||
|
||||
Reference in New Issue
Block a user