support streaming tts

Mirror of https://github.com/FunAudioLLM/CosyVoice.git
@@ -395,38 +395,45 @@ def run_sync_streaming_inference(
     # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
     actual_duration = 0
     if audios:
-        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
-        fade_out = np.linspace(1, 0, cross_fade_samples)
-        fade_in = np.linspace(0, 1, cross_fade_samples)
-        reconstructed_audio = None
-
-        # Simplified reconstruction based on client_grpc_streaming.py
-        if not audios:
-            print("Warning: No audio chunks received.")
-            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
-        elif len(audios) == 1:
-            reconstructed_audio = audios[0]
-        else:
-            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
-            for i in range(1, len(audios)):
-                # Cross-fade section
-                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                       audios[i - 1][-cross_fade_samples:] * fade_out)
-                # Middle section of the current chunk
-                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                # Concatenate
-                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-            # Add the last part of the final chunk
-            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
-
-        if reconstructed_audio is not None and reconstructed_audio.size > 0:
-            actual_duration = len(reconstructed_audio) / save_sample_rate
-            # Save reconstructed audio
-            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
-            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
-        else:
-            print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0  # Set duration to 0 if no audio
+        # Only spark_tts model uses cross-fade
+        if model_name == "spark_tts":
+            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+            reconstructed_audio = None
+
+            # Simplified reconstruction based on client_grpc_streaming.py
+            if not audios:
+                print("Warning: No audio chunks received.")
+                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+            elif len(audios) == 1:
+                reconstructed_audio = audios[0]
+            else:
+                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                for i in range(1, len(audios)):
+                    # Cross-fade section
+                    cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                           audios[i - 1][-cross_fade_samples:] * fade_out)
+                    # Middle section of the current chunk
+                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                    # Concatenate
+                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+                # Add the last part of the final chunk
+                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
+
+            if reconstructed_audio is not None and reconstructed_audio.size > 0:
+                actual_duration = len(reconstructed_audio) / save_sample_rate
+                # Save reconstructed audio
+                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
+            else:
+                print("Warning: No audio chunks received or reconstructed.")
+                actual_duration = 0  # Set duration to 0 if no audio
+        else:
+            reconstructed_audio = np.concatenate(audios)
+            print(f"reconstructed_audio: {reconstructed_audio.shape}")
+            actual_duration = len(reconstructed_audio) / save_sample_rate
+            # Save reconstructed audio
+            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
     else:
         print("Warning: No audio chunks received.")

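A note on the reconstruction above: it is plain overlap-add with linear fades, so a shared overlap region blends back to the original signal and the total length is preserved. A minimal standalone sketch (the helper name and test signal are illustrative, not part of this commit):

import numpy as np

def stitch_two_chunks(prev_chunk, next_chunk, overlap):
    # Linearly cross-fade the `overlap` samples shared by consecutive chunks.
    fade_out = np.linspace(1, 0, overlap)
    fade_in = np.linspace(0, 1, overlap)
    blended = prev_chunk[-overlap:] * fade_out + next_chunk[:overlap] * fade_in
    return np.concatenate([prev_chunk[:-overlap], blended, next_chunk[overlap:]])

# Two chunks of a 1 kHz sine that share a 160-sample overlap (10 ms at 16 kHz).
sr, overlap = 16000, 160
t = np.arange(2 * sr) / sr
signal = np.sin(2 * np.pi * 1000 * t).astype(np.float32)
chunk_a, chunk_b = signal[:sr + overlap], signal[sr:]
out = stitch_two_chunks(chunk_a, chunk_b, overlap)
assert len(out) == len(signal)              # overlap-add keeps the total length
assert np.allclose(out, signal, atol=1e-6)  # fade_in + fade_out == 1, so identical overlaps blend back exactly
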
@@ -667,6 +674,7 @@ async def main():
     manifest_item_list = split_data(manifest_item_list, num_tasks)

     os.makedirs(args.log_dir, exist_ok=True)
+
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):

@@ -114,6 +114,7 @@ class TritonPythonModel:
             "runtime_top_p": np.array([[0.95]], dtype=np.float32),
             "runtime_top_k": np.array([[50]], dtype=np.int32),
             "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }

@@ -144,6 +145,7 @@ class TritonPythonModel:

                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+                print(f"actual_output_ids: {actual_output_ids}")

                 yield actual_output_ids
             else:

@@ -193,7 +195,10 @@ class TritonPythonModel:
                           prompt_speech_tokens: torch.Tensor,
                           prompt_speech_feat: torch.Tensor,
                           prompt_spk_embedding: torch.Tensor,
-                          target_speech_tokens: torch.Tensor) -> torch.Tensor:
+                          target_speech_tokens: torch.Tensor,
+                          request_id: str,
+                          token_offset: int = None,
+                          finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.

         Args:

@@ -210,11 +215,22 @@ class TritonPythonModel:
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

+        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
-            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+            inputs=inputs_tensor,
+            request_id=request_id,
         )

         inference_response = inference_request.exec()

@@ -275,6 +291,7 @@ class TritonPythonModel:
         responses = []

         for request in requests:
+            request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

@@ -292,6 +309,11 @@ class TritonPythonModel:
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

+            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
+            token_hop_len = 25
+            flow_pre_lookahead_len = 3
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')

@@ -308,24 +330,46 @@ class TritonPythonModel:
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)

+            prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+            print(f"here2")
             if self.decoupled:
                 response_sender = request.get_response_sender()
-                request_id = request.request_id()
-                generated_ids = []
-                for generated_id in generated_ids_iter:
-                    # convert the numpy array into a int32 tensor
-                    generated_id = generated_id.tolist()
-                    if len(generated_id) > 0:
-                        assert len(generated_id) == 1, "Generated ID is not a single integer"
-                        generated_ids.append(generated_id[0])
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                semantic_token_ids_arr, token_offset = [], 0
+                for generated_ids in generated_ids_iter:
+
+                    generated_ids = generated_ids.tolist()
+                    print(f"generated_id: {generated_ids}")
+                    semantic_token_ids_arr.extend(generated_ids)
+
+                    prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len)
+                    this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
+                    print(f"this_token_hop_len: {this_token_hop_len}")
+                    if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
+                        print(f"this_tts_speech_token: {this_tts_speech_token}")
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                        print(f"here3")
+
+                        sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
+                        print(f"here4")
+                        # Prepare response to send
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        self.logger.log_info(f"[{request_id}]")
+                        token_offset += this_token_hop_len
+                        print(f"here")
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)

                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:

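The emission rule in the decoupled branch above is easier to see in isolation: semantic tokens are buffered until the current hop plus the flow lookahead are available, then everything up to that point is handed to token2wav; the first hop is padded so that the prompt plus the first chunk align to a multiple of token_hop_len. A self-contained sketch of the same arithmetic, using the constants from this diff and a fabricated token stream (the prompt length 37 is chosen only for illustration):

import numpy as np

token_hop_len = 25
flow_pre_lookahead_len = 3
flow_prompt_speech_token_len = 37   # illustrative prompt speech-token length

# Pad the first hop so prompt + first chunk land on a multiple of token_hop_len: ceil(37/25)*25 - 37 = 13.
prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len
                       - flow_prompt_speech_token_len)

semantic_token_ids_arr, token_offset = [], 0
for step in range(120):             # pretend the LLM yields one semantic token per iteration
    semantic_token_ids_arr.append(step)
    this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
    if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
        chunk = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
        print(f"emit {len(chunk)} tokens (offset {token_offset})")   # first at 41 buffered tokens, then every 25
        token_offset += this_token_hop_len
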
@@ -334,8 +378,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")

-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)

                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

@@ -32,12 +32,16 @@ from typing import List, Dict

 import torch
 from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F

 import triton_python_backend_utils as pb_utils

 from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np

 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

@@ -81,6 +85,13 @@ class CosyVoice2Model:
         if self.fp16 is True:
             self.flow.half()

+        # streaming tts config
+        self.token_hop_len = 25
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        self.hift_cache_dict = defaultdict(lambda: None)
+
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

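For scale, a rough conversion of those cache sizes into time, assuming CosyVoice 2's 24 kHz HiFT output and 480-sample mel hop (the 480 appears in the hunk above; the sample rate is an assumption, not stated in this diff):

mel_cache_len = 8                        # cached mel frames carried over between chunks
source_cache_len = mel_cache_len * 480   # 3840 waveform samples
print(source_cache_len / 24000 * 1000)   # 160.0 -> roughly 160 ms of audio is cross-faded between chunks
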
@@ -112,6 +123,43 @@ class CosyVoice2Model:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+
 class TritonPythonModel:
     """Triton Python model for vocoder.

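The per-request hift_cache_dict entry keeps the tail of the previous chunk (mel frames, HiFT source excitation, and waveform) so the next call has context, and the overlapping waveform is smoothed with a Hamming window. A rough numpy analogue of that waveform blend (the real fade_in_out helper operates on torch tensors inside the model; names and sizes here are illustrative):

import numpy as np

source_cache_len = 8 * 480                       # mel_cache_len * 480, as configured above
speech_window = np.hamming(2 * source_cache_len)

def blend_with_cached_tail(new_chunk, cached_tail, window):
    # Rising half of the Hamming window on the new audio, falling half on the cached tail.
    n = len(cached_tail)
    out = new_chunk.copy()
    out[:n] = new_chunk[:n] * window[:n] + cached_tail * window[n:]
    return out

cached_tail = np.random.randn(source_cache_len).astype(np.float32)    # like hift_cache_dict[uuid]['speech']
new_chunk = np.random.randn(4 * source_cache_len).astype(np.float32)  # next chunk from hift.inference
smoothed = blend_with_cached_tail(new_chunk, cached_tail, speech_window)
print(smoothed.shape)   # unchanged length; only the first source_cache_len samples were blended
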
@@ -166,25 +214,49 @@ class TritonPythonModel:
         prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
         target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

-        tts_mel, _ = self.token2wav_model.model.flow.inference(
-            token=target_speech_tokens,
-            token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                self.device
-            ),
-            prompt_token=prompt_speech_tokens,
-            prompt_token_len=torch.tensor(
-                [prompt_speech_tokens.shape[1]], dtype=torch.int32
-            ).to(self.device),
-            prompt_feat=prompt_speech_feat,
-            prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-            embedding=prompt_spk_embedding,
-            streaming=False,
-            finalize=True,
-        )
-
-        audio_hat, _ = self.token2wav_model.model.hift.inference(
-            speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-        )
+        # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
+        token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
+        if token_offset is not None:
+            token_offset = token_offset.as_numpy().item()
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+            if not finalize:
+                stream = True
+            else:
+                stream = False
+            request_id = request.request_id()
+            print(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
+            audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
+                                                             prompt_token=prompt_speech_tokens,
+                                                             prompt_feat=prompt_speech_feat,
+                                                             embedding=prompt_spk_embedding,
+                                                             token_offset=token_offset,
+                                                             uuid=request_id,
+                                                             stream=stream,
+                                                             finalize=finalize)
+            if finalize:
+                print(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
+                self.token2wav_model.model.hift_cache_dict.pop(request_id)
+        else:
+            tts_mel, _ = self.token2wav_model.model.flow.inference(
+                token=target_speech_tokens,
+                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                    self.device
+                ),
+                prompt_token=prompt_speech_tokens,
+                prompt_token_len=torch.tensor(
+                    [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                ).to(self.device),
+                prompt_feat=prompt_speech_feat,
+                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                embedding=prompt_spk_embedding,
+                streaming=False,
+                finalize=True,
+            )
+
+            audio_hat, _ = self.token2wav_model.model.hift.inference(
+                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+            )

         generated_wave = audio_hat.squeeze(0).cpu().numpy()

@@ -45,6 +45,20 @@ input [
     name: "prompt_spk_embedding"
     data_type: TYPE_FP16
     dims: [-1]
+  },
+  {
+    name: "token_offset"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
   }
 ]
 output [

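Because both new inputs are marked optional, existing offline callers keep working without sending them; in this commit the streaming caller is the BLS model shown earlier, which attaches them through pb_utils.Tensor in forward_token2wav. A client hitting the token2wav model directly over gRPC would do the equivalent roughly as below (a sketch only: input names and dtypes come from the config above, the [1, 1] shape assumes a batched model with batch size 1, and the rest of the request construction is not part of this commit):

import numpy as np
import tritonclient.grpc as grpcclient

def add_streaming_inputs(inputs, token_offset, finalize):
    # Append the optional streaming control tensors to an existing list of InferInput.
    token_offset_input = grpcclient.InferInput("token_offset", [1, 1], "INT32")
    token_offset_input.set_data_from_numpy(np.array([[token_offset]], dtype=np.int32))
    finalize_input = grpcclient.InferInput("finalize", [1, 1], "BOOL")
    finalize_input.set_data_from_numpy(np.array([[finalize]], dtype=np.bool_))
    inputs.extend([token_offset_input, finalize_input])
    return inputs
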