diff --git a/src/silero_vad/utils_vad.py b/src/silero_vad/utils_vad.py index 3fd5635..bd91441 100644 --- a/src/silero_vad/utils_vad.py +++ b/src/silero_vad/utils_vad.py @@ -490,18 +490,104 @@ class VADIterator: def collect_chunks(tss: List[dict], - wav: torch.Tensor): - chunks = [] - for i in tss: - chunks.append(wav[i['start']: i['end']]) + wav: torch.Tensor, + seconds: bool = False, + sampling_rate: int = None) -> torch.Tensor: + """Collect audio chunks from a longer audio clip + + This method extracts audio chunks from an audio clip, using a list of + provided coordinates, and concatenates them together. Coordinates can be + passed either as sample numbers or in seconds, in which case the audio + sampling rate is also needed. + + Parameters + ---------- + tss: List[dict] + Coordinate list of the clips to collect from the audio. + wav: torch.Tensor, one dimensional + One dimensional float torch.Tensor, containing the audio to clip. + seconds: bool (default - False) + Whether input coordinates are passed as seconds or samples. + sampling_rate: int (default - None) + Input audio sampling rate. Required if seconds is True. + + Returns + ------- + torch.Tensor, one dimensional + One dimensional float torch.Tensor of the concatenated clipped audio + chunks. + + Raises + ------ + ValueError + Raised if sampling_rate is not provided when seconds is True. + + """ + if seconds and not sampling_rate: + raise ValueError('sampling_rate must be provided when seconds is True') + + chunks = list() + _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss + + for i in _tss: + chunks.append(wav[i['start']:i['end']]) + return torch.cat(chunks) def drop_chunks(tss: List[dict], - wav: torch.Tensor): - chunks = [] + wav: torch.Tensor, + seconds: bool = False, + sampling_rate: int = None) -> torch.Tensor: + """Drop audio chunks from a longer audio clip + + This method extracts audio chunks from an audio clip, using a list of + provided coordinates, and drops them. Coordinates can be passed either as + sample numbers or in seconds, in which case the audio sampling rate is also + needed. + + Parameters + ---------- + tss: List[dict] + Coordinate list of the clips to drop from from the audio. + wav: torch.Tensor, one dimensional + One dimensional float torch.Tensor, containing the audio to clip. + seconds: bool (default - False) + Whether input coordinates are passed as seconds or samples. + sampling_rate: int (default - None) + Input audio sampling rate. Required if seconds is True. + + Returns + ------- + torch.Tensor, one dimensional + One dimensional float torch.Tensor of the input audio minus the dropped + chunks. + + Raises + ------ + ValueError + Raised if sampling_rate is not provided when seconds is True. + + """ + if seconds and not sampling_rate: + raise ValueError('sampling_rate must be provided when seconds is True') + + chunks = list() cur_start = 0 - for i in tss: + + _tss = _seconds_to_samples_tss(tss, sampling_rate) if seconds else tss + + for i in _tss: chunks.append((wav[cur_start: i['start']])) cur_start = i['end'] + return torch.cat(chunks) + + +def _seconds_to_samples_tss(tss: List[dict], sampling_rate: int) -> List[dict]: + """Convert coordinates expressed in seconds to sample coordinates. + """ + return [{ + 'start': round(crd['start']) * sampling_rate, + 'end': round(crd['end']) * sampling_rate + } for crd in tss]