mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix white space
This commit is contained in:
@@ -62,7 +62,7 @@ def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
|
||||
# Create a SHA-256 hash of the byte string
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(tensor_bytes)
|
||||
|
||||
|
||||
return hasher.hexdigest()
|
||||
|
||||
class TritonPythonModel:
|
||||
@@ -111,9 +111,9 @@ class TritonPythonModel:
|
||||
target_speech_tokens = target_speech_tokens.squeeze().tolist()
|
||||
|
||||
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
|
||||
|
||||
|
||||
request_id = request.request_id()
|
||||
|
||||
|
||||
|
||||
wav_array = pb_utils.get_input_tensor_by_name(
|
||||
request, "reference_wav").as_numpy()
|
||||
|
||||
@@ -133,7 +133,6 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
option.intra_op_num_threads = 1
|
||||
self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
|
||||
providers=["CPUExecutionProvider"])
|
||||
|
||||
self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
|
||||
|
||||
gpu="l20"
|
||||
@@ -253,7 +252,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
||||
prompt_speech_tokens_list.append(speech_tokens_i)
|
||||
return prompt_speech_tokens_list
|
||||
|
||||
|
||||
def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
|
||||
spk_emb_for_flow = []
|
||||
for audio in prompt_audios_list:
|
||||
@@ -263,11 +262,11 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
spk_emb = self.forward_spk_embedding(spk_feat)
|
||||
|
||||
spk_emb_for_flow.append(spk_emb)
|
||||
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
||||
if self.dtype != torch.float32:
|
||||
spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
|
||||
return spk_emb_for_flow
|
||||
|
||||
|
||||
def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
|
||||
prompt_mels_for_flow = []
|
||||
prompt_mels_lens_for_flow = []
|
||||
@@ -283,7 +282,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
|
||||
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
||||
return prompt_mels_for_flow, prompt_mels_lens_for_flow
|
||||
|
||||
|
||||
def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
|
||||
batch_size = prompt_mels_for_flow.shape[0]
|
||||
flow_inputs = []
|
||||
@@ -318,28 +317,24 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
|
||||
prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
|
||||
|
||||
generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
|
||||
|
||||
return generated_wavs
|
||||
|
||||
def prepare_prompt_audio(
|
||||
self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
|
||||
):
|
||||
assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
|
||||
|
||||
|
||||
prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
|
||||
|
||||
prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
|
||||
|
||||
return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
|
||||
|
||||
|
||||
@@ -365,7 +360,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
@torch.inference_mode()
|
||||
def forward_streaming(
|
||||
self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
|
||||
):
|
||||
):
|
||||
if speaker_id not in self.speaker_cache:
|
||||
assert prompt_audio is not None, "prompt_audio is required for new speaker"
|
||||
assert prompt_audio_sample_rate == 16000
|
||||
@@ -384,7 +379,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
|
||||
if request_id not in self.streaming_flow_cache:
|
||||
self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
|
||||
self.hift_cache_dict[request_id] = dict(
|
||||
mel = torch.zeros(1, 80, 0, device='cuda'),
|
||||
mel = torch.zeros(1, 80, 0, device='cuda'),
|
||||
source = torch.zeros(1, 1, 0, device='cuda'),
|
||||
speech = torch.zeros(1, 0, device='cuda'),
|
||||
)
|
||||
@@ -445,7 +440,7 @@ def collate_fn(batch):
|
||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
|
||||
for i, item in enumerate(batch):
|
||||
generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
|
||||
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||
audio = torch.from_numpy(item['prompt_audio']['array']).float()
|
||||
prompt_audios_list.append(audio)
|
||||
prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
|
||||
ids.append(item['id'])
|
||||
@@ -473,20 +468,20 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
|
||||
|
||||
|
||||
|
||||
|
||||
for epoch in range(args.warmup):
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
for batch in data_loader:
|
||||
ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
|
||||
|
||||
generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
|
||||
|
||||
|
||||
|
||||
for id, wav in zip(ids, generated_wavs):
|
||||
torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
epoch_time = end_time - start_time
|
||||
print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")
|
||||
Reference in New Issue
Block a user