mirror of https://github.com/FunAudioLLM/CosyVoice.git
fix pitch computation
@@ -20,6 +20,7 @@ import torch
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
+import pyworld as pw


 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
@@ -178,7 +179,7 @@ def compute_fbank(data,
         yield sample


-def compute_f0(data, pitch_extractor, mode='train'):
+def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0

         Args:
@@ -187,15 +188,19 @@ def compute_f0(data, pitch_extractor, mode='train'):
         Returns:
             Iterable[{key, feat, label}]
     """
+    frame_period = hop_size * 1000 / sample_rate
     for sample in data:
         assert 'sample_rate' in sample
         assert 'speech' in sample
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = pitch_extractor(waveform).transpose(1, 2)
-        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
-        sample['pitch_feat'] = mat[0, 0]
+        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
+        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
+            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
+        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
+        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
+        sample['pitch_feat'] = f0
         yield sample


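For context on this hunk: the commit replaces a learned pitch_extractor module with pyworld's classical pipeline (harvest for the initial F0 contour, dio as a fallback when harvest finds fewer than 5 voiced frames, stonemask to refine the per-frame estimates). frame_period is the analysis hop expressed in milliseconds, so pyworld yields roughly one F0 value per mel frame; the final F.interpolate then aligns the contour exactly to speech_feat's frame count. Below is a minimal standalone sketch of the same computation. The 24 kHz sample rate, 480-sample hop, synthetic 220 Hz tone, and num_mel_frames are illustrative assumptions, not values from the commit:

# Standalone sketch of the pyworld-based F0 path in compute_f0 above.
# sample_rate / hop_size / num_mel_frames are assumed, for illustration only.
import numpy as np
import torch
import torch.nn.functional as F
import pyworld as pw

sample_rate, hop_size = 24000, 480            # assumed values
frame_period = hop_size * 1000 / sample_rate  # 480 * 1000 / 24000 = 20.0 ms per frame

# 1 s of a 220 Hz tone as a stand-in for sample['speech']
t_axis = np.arange(sample_rate) / sample_rate
waveform = np.sin(2 * np.pi * 220 * t_axis).astype('double')

# harvest is the primary extractor; fall back to dio when it returns
# almost no voiced frames (fewer than 5 non-zero F0 values)
_f0, t = pw.harvest(waveform, sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5:
    _f0, t = pw.dio(waveform, sample_rate, frame_period=frame_period)
f0 = pw.stonemask(waveform, _f0, t, sample_rate)  # refine per-frame estimates

# Stretch/squeeze the contour to exactly match the mel frame count,
# mirroring size=sample['speech_feat'].shape[0] in compute_f0.
num_mel_frames = 50  # stand-in for sample['speech_feat'].shape[0]
pitch_feat = F.interpolate(torch.from_numpy(f0).view(1, 1, -1),
                           size=num_mel_frames, mode='linear').view(-1)
print(pitch_feat.shape)  # torch.Size([50])
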
@@ -15,6 +15,7 @@
 # limitations under the License.

 import torch
+from cosyvoice.utils.file_utils import logging
 '''
 def subsequent_mask(
         size: int,
@@ -230,6 +231,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
|
||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||
else:
|
||||
chunk_masks = masks
|
||||
assert chunk_masks.dtype == torch.bool
|
||||
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
||||
logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
||||
chunk_masks[chunk_masks.sum(dim=-1)==0] = True
|
||||
return chunk_masks
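This guard addresses a numerical hazard: if any query row of chunk_masks is entirely False, the usual masked_fill sets that whole attention row to -inf and softmax over it produces NaN. Forcing such rows to True keeps softmax finite; the warning reminds you that those positions must still be masked out downstream. A small sketch of the failure mode and the fix, with illustrative tensor shapes (not taken from the repo):

# Why all-false mask rows are dangerous: softmax over a fully masked
# row produces NaN. Shapes here are illustrative only.
import torch

scores = torch.randn(1, 4, 4)                    # (B, L, L) attention logits
chunk_masks = torch.ones(1, 4, 4, dtype=torch.bool)
chunk_masks[0, 3, :] = False                     # one query row fully masked

attn = scores.masked_fill(~chunk_masks, float('-inf')).softmax(dim=-1)
print(attn[0, 3])                                # tensor([nan, nan, nan, nan])

# The committed fix: force all-false rows to True so softmax stays
# finite; these positions are expected to be masked again later.
bad_rows = chunk_masks.sum(dim=-1) == 0          # (B, L) rows with no valid key
if bad_rows.any():
    chunk_masks[bad_rows] = True

attn = scores.masked_fill(~chunk_masks, float('-inf')).softmax(dim=-1)
print(torch.isnan(attn).any())                   # tensor(False)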