mirror of https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix white space
@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
Create two JSONL files—`train.jsonl` and `test.jsonl`.
The script will then generate two Parquet files:

```
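The exact JSONL schema is not shown in this hunk. As a hedged illustration only, a record in `train.jsonl` might pair an utterance id with its transcript and audio path; the field names below are assumptions, not the repository's documented schema:

```python
# Hypothetical sketch of writing one train.jsonl record; the real field
# names consumed by run.sh stage 0 may differ.
import json

record = {
    "id": "utt_0001",                  # assumed field
    "text": "An example transcript.",  # assumed field
    "wav_path": "data/utt_0001.wav",   # assumed field
}
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```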
@@ -111,7 +111,7 @@ bash run.sh 5 5

The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
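As a rough sketch of what such a conversion involves (not the repository's actual script), the Hugging Face weights can be loaded and re-keyed into a single checkpoint file; the key prefix and output name below are assumptions:

```python
# Hedged sketch only: load an HF causal LM and save a re-prefixed state
# dict. The repo's real conversion script may remap keys differently.
import torch
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("path/to/merged_hf_model")  # path assumed
state_dict = {f"llm.{k}": v for k, v in hf_model.state_dict().items()}  # prefix assumed
torch.save(state_dict, "llm.pt")  # output name assumed
```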

## Results

@@ -33,7 +33,7 @@ fi

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
  python3 pretrained_to_huggingface.py \
    --pretrained-cosyvoice2-path $model_scope_model_path \
    --save-path $sft_model_path
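For reference, the same download can be done from Python. This is a sketch using modelscope's `snapshot_download`; the `local_dir` argument requires a recent modelscope release, and the destination path is assumed:

```python
# Python equivalent of the `modelscope download` CLI call above.
from modelscope import snapshot_download

snapshot_download("iic/CosyVoice2-0.5B", local_dir="pretrained/CosyVoice2-0.5B")
```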
@@ -61,7 +61,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "stage 1: start token2wav asr server for reward function"
  python3 token2wav_asr_server.py --number-of-devices 8
fi

exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
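The server above supplies the reward signal: sampled speech tokens are rendered to audio, transcribed, and scored against the reference text. A minimal sketch of such a transcript-similarity reward, assuming the real implementation uses a proper CER/WER metric rather than this stand-in:

```python
# Hedged sketch: similarity between reference text and an ASR transcript.
# difflib's ratio() stands in for 1 - CER here.
import difflib

def transcript_reward(reference: str, transcript: str) -> float:
    return difflib.SequenceMatcher(None, reference, transcript).ratio()

print(transcript_reward("hello world", "hello word"))  # close to 1.0
```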
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    --backend fsdp \
    --local_dir $llm_path/actor \
    --target_dir $llm_path/merged_hf_model || exit 1
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "stage 4: Test the model"
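Once the FSDP shards are merged into `$llm_path/merged_hf_model`, the result should load like any Hugging Face causal LM. A minimal sanity-check sketch (the concrete path is assumed from `exp_name`):

```python
# Sanity check for the merged checkpoint produced in stage 3.
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "exp/official_llm_aishell3_grpo/merged_hf_model"  # path assumed
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)
```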
@@ -254,7 +254,7 @@ class TritonPythonModel:
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
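In Triton's Python backend, such a BLS request is then executed and its outputs unpacked. A sketch of the typical pattern in an async model, continuing the names above (the hunk cuts off before the repository's own execution code):

```python
# Continuation sketch, Triton BLS: execute the request built above and
# read back the generated waveform tensor.
inference_response = await inference_request.async_exec()
if inference_response.has_error():
    raise pb_utils.TritonModelException(inference_response.error().message())
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
```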
@@ -362,7 +362,7 @@ class TritonPythonModel:
                chunk_index += 1
            else:
                break

        this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
        sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
@@ -62,7 +62,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    # Create a SHA-256 hash of the byte string
    hasher = hashlib.sha256()
    hasher.update(tensor_bytes)

    return hasher.hexdigest()

class TritonPythonModel:
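The hash gives a deterministic speaker id: the same prompt waveform always produces the same digest, so it can key a per-speaker cache. A self-contained illustration (the byte serialization is assumed to match the function's `tensor_bytes`):

```python
# Minimal demo of the SHA-256 speaker-id idea used above.
import hashlib

import torch

def spk_id(tensor: torch.Tensor) -> str:
    hasher = hashlib.sha256()
    hasher.update(tensor.numpy().tobytes())  # serialization assumed
    return hasher.hexdigest()

a = spk_id(torch.zeros(16000))
b = spk_id(torch.zeros(16000))
assert a == b  # identical audio -> identical speaker id
```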
@@ -111,9 +111,9 @@ class TritonPythonModel:
        target_speech_tokens = target_speech_tokens.squeeze().tolist()

        finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()

        request_id = request.request_id()

        wav_array = pb_utils.get_input_tensor_by_name(
            request, "reference_wav").as_numpy()
@@ -133,7 +133,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
                                                      providers=["CPUExecutionProvider"])

        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()

        gpu = "l20"
@@ -253,7 +252,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list

    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
@@ -263,11 +262,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
            spk_emb = self.forward_spk_embedding(spk_feat)

            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        if self.dtype != torch.float32:
            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
@@ -283,7 +282,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
        batch_size = prompt_mels_for_flow.shape[0]
        flow_inputs = []
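The `pad_sequence` call above batches variable-length mel features. A small self-contained example of the same call:

```python
# pad_sequence stacks [T_i, 80] mels into one [B, T_max, 80] batch.
import torch

mels = [torch.randn(120, 80), torch.randn(95, 80)]
batch = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True, padding_value=0)
lens = torch.tensor([m.shape[0] for m in mels])
print(batch.shape, lens)  # torch.Size([2, 120, 80]) tensor([120,  95])
```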
@@ -318,28 +317,24 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)

        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)

        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)

        return generated_wavs

    def prepare_prompt_audio(
        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)

        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)

        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)

        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)

        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow

@@ -365,7 +360,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
    @torch.inference_mode()
    def forward_streaming(
        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
    ):
        if speaker_id not in self.speaker_cache:
            assert prompt_audio is not None, "prompt_audio is required for new speaker"
            assert prompt_audio_sample_rate == 16000
@@ -384,7 +379,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
        if request_id not in self.streaming_flow_cache:
            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
            self.hift_cache_dict[request_id] = dict(
                mel=torch.zeros(1, 80, 0, device='cuda'),
                source=torch.zeros(1, 1, 0, device='cuda'),
                speech=torch.zeros(1, 0, device='cuda'),
            )
@@ -445,7 +440,7 @@ def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for i, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
@@ -473,20 +468,20 @@ if __name__ == "__main__":

    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for epoch in range(args.warmup):
        start_time = time.time()

        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)

            for id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)

        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
@@ -365,7 +365,6 @@ def main(args):
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

    if 'Step-Audio-2-mini' in args.token2wav_path:
        from token2wav_dit import CosyVoice2_Token2Wav
    else:
@@ -14,7 +14,7 @@ def collate_fn(batch):
    prompt_speech_tokens_list, prompt_text_list = [], []
    for i, item in enumerate(batch):
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
@@ -37,7 +37,7 @@ def get_args():

if __name__ == "__main__":
    args = get_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

@@ -46,7 +46,7 @@ if __name__ == "__main__":
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)

    CHUNK_SIZE = 25
    token_frame_rate = 25
    OVERLAP_SIZE = 0
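With a 25 Hz token rate and `CHUNK_SIZE = 25`, each streaming chunk covers one second of audio. A hedged sketch of the chunking loop these constants imply (the script's actual loop also tracks `token_offset` against the flow cache):

```python
# Sketch of fixed-size streaming chunking with optional overlap.
CHUNK_SIZE, OVERLAP_SIZE = 25, 0

def iter_chunks(tokens):
    offset = 0
    while len(tokens) - offset > CHUNK_SIZE:
        yield tokens[offset:offset + CHUNK_SIZE], False  # intermediate chunk
        offset += CHUNK_SIZE - OVERLAP_SIZE
    yield tokens[offset:], True  # final chunk flushes the remainder

for chunk, last in iter_chunks(list(range(60))):
    print(len(chunk), last)  # 25 False / 25 False / 10 True
```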
@@ -68,7 +68,7 @@ if __name__ == "__main__":

        semantic_token_ids_arr, token_offset = [], 0
        flow_prompt_speech_token_len = len(prompt_speech_tokens)

        buffer = generated_speech_tokens
        output_wavs = []
        chunk_index = 0
@@ -97,7 +97,7 @@ if __name__ == "__main__":
            output_wavs[i] = wav.cpu().numpy().squeeze()

        audios = output_wavs
        reconstructed_audio = np.concatenate(audios)
        sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")