remove flow_cache

lyuxiang.lx
2025-05-23 12:50:47 +08:00
parent 88f467a8ac
commit 68100c267a
14 changed files with 365 additions and 955 deletions

View File

@@ -126,7 +126,7 @@ import torchaudio
**CosyVoice2 Usage**
```python
- cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
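# Hedged illustration of the surrounding README example (abridged); load_wav is imported
# earlier in the README, and the prompt wav path and texts here are placeholders.
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, out in enumerate(cosyvoice.inference_zero_shot('Text to synthesize.', 'Transcript of the prompt audio.', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)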

View File

@@ -61,8 +61,7 @@ def main():
model = CosyVoice(args.model_dir)
except Exception:
try:
- # NOTE set use_flow_cache=True when export jit for cache inference
- model = CosyVoice2(args.model_dir, use_flow_cache=True)
+ model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
@@ -93,9 +92,9 @@ def main():
else:
# 3. export flow encoder
flow_encoder = model.model.flow.encoder
- script = get_optimized_script(flow_encoder, ['forward_chunk'])
+ script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
- script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+ script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
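The dropped `['forward_chunk']` argument was a list of extra methods to keep when freezing the scripted encoder; with `forward_chunk` gone from the flow encoder, plain scripting is enough. A hedged sketch of what a helper like `get_optimized_script` may look like (the repo's actual implementation is not shown in this diff):

```python
import torch

def get_optimized_script(module: torch.nn.Module, preserved_methods=None):
    # Sketch only: script the module in eval mode, freeze it (optionally keeping
    # extra methods such as 'forward_chunk' reachable), then apply inference passes.
    script = torch.jit.script(module.eval())
    if preserved_methods:
        script = torch.jit.freeze(script, preserved_attrs=preserved_methods)
    else:
        script = torch.jit.freeze(script)
    return torch.jit.optimize_for_inference(script)
```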

View File

@@ -62,135 +62,58 @@ def main():
model = CosyVoice(args.model_dir)
except Exception:
try:
- # NOTE set use_flow_cache=True when export jit for cache inference
- model = CosyVoice2(args.model_dir, use_flow_cache=True)
+ model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
- if not isinstance(model, CosyVoice2):
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator
estimator.eval()
device = model.model.device
batch_size, seq_len = 2, 256
out_channels = model.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
torch.onnx.export(
estimator,
(x, mask, mu, t, spks, cond),
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['estimator_out'],
dynamic_axes={
'x': {2: 'seq_len'},
'mask': {2: 'seq_len'},
'mu': {2: 'seq_len'},
'cond': {2: 'seq_len'},
'estimator_out': {2: 'seq_len'},
}
)
# 2. test computation consistency
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
sess_options=option, providers=providers)
for _ in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
output_pytorch = estimator(x, mask, mu, t, spks, cond)
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy()
}
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
logging.info('successfully export estimator')
else:
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator
estimator.forward = estimator.forward_chunk
estimator.eval()
device = model.model.device
batch_size, seq_len = 2, 256
out_channels = model.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
cache = model.model.init_flow_cache()['decoder_cache']
cache.pop('offset')
cache = {k: v[0] for k, v in cache.items()}
torch.onnx.export(
estimator,
(x, mask, mu, t, spks, cond,
cache['down_blocks_conv_cache'],
cache['down_blocks_kv_cache'],
cache['mid_blocks_conv_cache'],
cache['mid_blocks_kv_cache'],
cache['up_blocks_conv_cache'],
cache['up_blocks_kv_cache'],
cache['final_blocks_conv_cache']),
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
dynamic_axes={
'x': {2: 'seq_len'},
'mask': {2: 'seq_len'},
'mu': {2: 'seq_len'},
'cond': {2: 'seq_len'},
'down_blocks_kv_cache': {3: 'cache_in_len'},
'mid_blocks_kv_cache': {3: 'cache_in_len'},
'up_blocks_kv_cache': {3: 'cache_in_len'},
'estimator_out': {2: 'seq_len'},
'down_blocks_kv_cache_out': {3: 'cache_out_len'},
'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
'up_blocks_kv_cache_out': {3: 'cache_out_len'},
}
)
# 2. test computation consistency
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
sess_options=option, providers=providers)
for iter in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
cache = model.model.init_flow_cache()['decoder_cache']
cache.pop('offset')
cache = {k: v[0] for k, v in cache.items()}
output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy(),
}
output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
if iter == 0:
# NOTE why can not pass first iteration check?
continue
for i, j in zip(output_pytorch, output_onnx):
torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
logging.info('successfully export estimator')
if __name__ == "__main__":
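`get_dummy_input` is defined elsewhere in this script and is unchanged by the commit; from the `input_names`, `dynamic_axes` and the TensorRT shape profiles elsewhere in this diff, its outputs plausibly have the shapes below. A sketch under that assumption:

```python
import torch

def get_dummy_input(batch_size, seq_len, out_channels, device):
    # Shapes inferred from the export call above; treat them as an assumption.
    x = torch.rand((batch_size, out_channels, seq_len), device=device)     # noisy mel
    mask = torch.ones((batch_size, 1, seq_len), device=device)             # frame mask
    mu = torch.rand((batch_size, out_channels, seq_len), device=device)    # encoder output
    t = torch.rand((batch_size,), device=device)                           # flow-matching timestep
    spks = torch.rand((batch_size, out_channels), device=device)           # speaker embedding
    cond = torch.rand((batch_size, out_channels, seq_len), device=device)  # prompt mel condition
    return x, mask, mu, t, spks, cond
```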

View File

@@ -140,7 +140,7 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1):
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.instruct = True if '-Instruct' in model_dir else False
self.model_dir = model_dir
self.fp16 = fp16
@@ -162,9 +162,9 @@ class CosyVoice2(CosyVoice):
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
- self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent)
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
self.model.load('{}/llm.pt'.format(model_dir),
- '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))

View File

@@ -33,12 +33,14 @@ class CosyVoiceModel:
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
- fp16: bool = False):
+ fp16: bool = False,
+ trt_concurrent: int = 1):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
+ self.trt_concurrent = trt_concurrent
if self.fp16 is True:
self.llm.half()
self.flow.half()
@@ -85,23 +87,18 @@ class CosyVoiceModel:
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
- if not os.path.exists(flow_decoder_estimator_model):
+ if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
- if os.path.getsize(flow_decoder_estimator_model) == 0:
- raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
- if isinstance(self, CosyVoice2Model):
- self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
- else:
- self.flow.decoder.estimator = estimator_engine.create_execution_context()
+ self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
- opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
+ opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
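These min/opt/max shapes are consumed by convert_onnx_to_trt when it builds the TensorRT engine; the widened opt_shape (200 to 500 frames) biases kernel selection toward longer non-streaming inputs. A hedged sketch of how such kwargs typically become a TensorRT optimization profile (convert_onnx_to_trt itself is not shown in this diff and may differ):

```python
import tensorrt as trt

def build_profile_sketch(builder: trt.Builder, config: trt.IBuilderConfig, trt_kwargs: dict):
    # One optimization profile covering min/opt/max sequence lengths for each named input.
    profile = builder.create_optimization_profile()
    for name, min_s, opt_s, max_s in zip(trt_kwargs['input_names'], trt_kwargs['min_shape'],
                                         trt_kwargs['opt_shape'], trt_kwargs['max_shape']):
        profile.set_shape(name, min=min_s, opt=opt_s, max=max_s)
    config.add_optimization_profile(profile)
    return config
```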
@@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel):
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False,
- use_flow_cache: bool = False,
trt_concurrent: int = 1):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
+ # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
+ self.flow.encoder.streaming = True
+ self.flow.decoder.estimator.streaming = True
self.hift = hift
self.fp16 = fp16
- self.use_flow_cache = use_flow_cache
self.trt_concurrent = trt_concurrent
if self.fp16 is True:
self.llm.half()
self.flow.half()
- # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+ # NOTE must matching training static_chunk_size
self.token_hop_len = 25
- self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
@@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel):
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
- self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.trt_context_dict = {}
def init_flow_cache(self):
encoder_cache = {'offset': 0,
'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
'upsample_offset': 0,
'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
decoder_cache = {'offset': 0,
'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
if self.fp16 is True:
for cache in [encoder_cache, decoder_cache]:
for k, v in cache.items():
if isinstance(v, torch.Tensor):
cache[k] = v.half()
cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
return cache
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
- def get_trt_kwargs(self):
- min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
- opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
- max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
- input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
- assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
- return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
- def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
- tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+ tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
- cache=self.flow_cache_dict[uuid],
finalize=finalize)
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel):
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
- self.flow_cache_dict[this_uuid] = self.init_flow_cache()
self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
@@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel):
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
- assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
- # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
- flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
- prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
+ token_offset = 0
+ prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
- if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
- this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+ this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
+ token_offset=token_offset,
uuid=this_uuid,
finalize=False)
- # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
- flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
- prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+ token_offset += this_token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}
- with self.lock:
- self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
- if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
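Streaming synthesis now relies on offset bookkeeping instead of a flow cache: the first chunk is padded so that prompt length plus hop size lands on a token_hop_len boundary, every chunk re-runs the flow over all tokens seen so far, and token2wav trims the mel by token_offset before vocoding. A standalone sketch of that scheduling arithmetic (token_hop_len, pre_lookahead_len and the token counts are illustrative values, not taken from this diff):

```python
import numpy as np

token_hop_len, pre_lookahead_len = 25, 3   # illustrative values
prompt_len, generated = 53, 120            # prompt tokens / tokens emitted by the LLM so far

# pad the first hop so prompt_len + first_hop is a multiple of token_hop_len
prompt_token_pad = int(np.ceil(prompt_len / token_hop_len) * token_hop_len - prompt_len)

token_offset, offsets = 0, []
while True:
    this_hop = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
    if generated - token_offset < this_hop + pre_lookahead_len:
        break                              # wait for more LLM tokens (or finalize)
    # token2wav would decode tokens[:token_offset + this_hop + pre_lookahead_len]
    # and keep only the mel after token_offset * token_mel_ratio
    token_offset += this_hop
    offsets.append(token_offset)
print(prompt_token_pad, offsets)           # 22 [47, 72, 97]
```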
@@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
+ token_offset=token_offset,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
- assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
+ token_offset=0,
uuid=this_uuid,
finalize=True,
speed=speed)
@@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
- self.flow_cache_dict.pop(this_uuid)
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
self.trt_context_dict.pop(this_uuid)
if torch.cuda.is_available():

View File

@@ -159,7 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
def compute_fbank(data,
feat_extractor,
- token_mel_ratio=2,
+ token_mel_ratio=0,
mode='train'):
""" Extract fbank
@@ -176,12 +176,11 @@ def compute_fbank(data,
assert 'text_token' in sample
waveform = sample['speech']
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+ if token_mel_ratio != 0:
# trim to align speech_token and speech_feat
- token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
+ token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
feat = feat[:token_mel_ratio * token_len]
sample["speech_token"] = sample["speech_token"][:token_len]
sample['speech_feat'] = feat
yield sample
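The trim only runs when a non-zero token_mel_ratio is configured (for CosyVoice2-style data, 2 mel frames per speech token); a small worked example of the alignment:

```python
import torch

token_mel_ratio = 2                      # mel frames per speech token (example value)
feat = torch.randn(203, 80)              # 203 mel frames, 80 bins
speech_token = torch.zeros(100).long()   # 100 speech tokens

token_len = int(min(feat.shape[0] / token_mel_ratio, speech_token.shape[0]))  # min(101.5, 100) -> 100
feat = feat[:token_mel_ratio * token_len]           # -> 200 frames
speech_token = speech_token[:token_len]             # -> 100 tokens
print(feat.shape[0], speech_token.shape[0])         # 200 100
```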

View File

@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import Tuple, Optional, Dict, Any
+ from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
- from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
- from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
+ from matcha.models.components.transformer import BasicTransformerBlock
class Transpose(torch.nn.Module):
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
self.dim0 = dim0
self.dim1 = dim1
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.transpose(x, self.dim0, self.dim1)
return x
@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
assert stride == 1
self.causal_padding = kernel_size - 1
- def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
- if cache.size(2) == 0:
- x = F.pad(x, (self.causal_padding, 0), value=0.0)
- else:
- assert cache.size(2) == self.causal_padding
- x = torch.concat([cache, x], dim=2)
- cache = x[:, :, -self.causal_padding:]
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
x = super(CausalConv1d, self).forward(x)
- return x, cache
+ return x
class CausalBlock1D(Block1D):
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
nn.Mish(),
)
- def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
- output, cache = self.block[0](x * mask, cache)
- for i in range(1, len(self.block)):
- output = self.block[i](output)
- return output * mask, cache
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ output = self.block(x * mask)
+ return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
h, block1_cache = self.block1(x, mask, block1_cache)
h += self.mlp(time_emb).unsqueeze(-1)
h, block2_cache = self.block2(h, mask, block2_cache)
output = h + self.res_conv(x * mask)
return output, block1_cache, block2_cache
class CausalAttnProcessor2_0(AttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
super(CausalAttnProcessor2_0, self).__init__()
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
*args,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.Tensor]:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
# NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key_cache = attn.to_k(encoder_hidden_states)
value_cache = attn.to_v(encoder_hidden_states)
# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
if cache.size(0) != 0:
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
else:
key, value = key_cache, value_cache
cache = torch.stack([key_cache, value_cache], dim=3)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states, cache
@maybe_allow_in_graph
class CausalAttention(Attention):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor2_0"] = None,
out_dim: int = None,
):
super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
_from_deprecated_attn_block, processor, out_dim)
processor = CausalAttnProcessor2_0()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
**cross_attention_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cache=cache,
**cross_attention_kwargs,
)
@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
):
super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
cross_attention_dim, activation_fn, num_embeds_ada_norm,
attention_bias, only_cross_attention, double_self_attention,
upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
self.attn1 = CausalAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
) -> Tuple[torch.Tensor, torch.Tensor]:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output, cache = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
cache=cache,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states, cache
class ConditionalDecoder(nn.Module):
def __init__(
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
- CausalBasicTransformerBlock(
+ BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
transformer_blocks = nn.ModuleList(
[
- CausalBasicTransformerBlock(
+ BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
)
transformer_blocks = nn.ModuleList(
[
- CausalBasicTransformerBlock(
+ BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
@@ -724,6 +419,9 @@ class CausalConditionalDecoder(ConditionalDecoder):
Returns:
_type_: _description_
"""
+ if hasattr(self, 'streaming'):
+ assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
+ streaming = self.streaming
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
@@ -740,36 +438,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
- x, _, _ = resnet(x, mask_down, t)
+ x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
- attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
- x, _ = transformer_block(
+ x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
- x, _ = downsample(x * mask_down)
+ x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
- x, _, _ = resnet(x, mask_mid, t)
+ x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
- attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
- x, _ = transformer_block(
+ x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
@@ -780,124 +478,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
- x, _, _ = resnet(x, mask_up, t)
+ x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
- attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
- x, _ = transformer_block(
+ x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
- x, _ = upsample(x * mask_up)
+ x = upsample(x * mask_up)
- x, _ = self.final_block(x, mask_up)
+ x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
@torch.inference_mode()
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
mask_down = masks[-1]
x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, down_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=down_blocks_kv_cache[index, i],
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, mid_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=mid_blocks_kv_cache[index, i]
)
x = rearrange(x, "b t c -> b c t").contiguous()
for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
attn_mask = mask_to_bias(attn_mask, x.dtype)
for i, transformer_block in enumerate(transformer_blocks):
x, up_blocks_kv_cache_new[index, i] = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
cache=up_blocks_kv_cache[index, i]
)
x = rearrange(x, "b t c -> b c t").contiguous()
x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
output = self.final_proj(x * mask_up)
return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
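Streaming behaviour in the decoder is now driven by the streaming attribute that CosyVoice2Model sets on the flow encoder and estimator: when it is True, the forward pass above builds chunk-based causal attention masks (static_chunk_size chunks with unlimited left context) instead of full attention, so the removed forward_chunk/KV-cache path is no longer needed. A toy illustration of what such a chunk-causal mask looks like (the repo builds the real one with add_optional_chunk_mask):

```python
import torch

def toy_chunk_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # Position i may attend to every position whose chunk index is <= i's chunk index,
    # i.e. full attention inside a chunk plus all previous chunks (unlimited left context).
    idx = torch.arange(seq_len)
    chunk_id = idx // chunk_size
    return chunk_id.unsqueeze(1) >= chunk_id.unsqueeze(0)   # (seq_len, seq_len) bool

print(toy_chunk_causal_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])
```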

View File

@@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
- cache,
finalize):
assert token.shape[0] == 1
# xvec projection
@@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
# text encode
if finalize is True:
- h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+ h, h_lengths = self.encoder(token, token_len)
else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
- h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
+ h, h_lengths = self.encoder(token, token_len, context=context)
- cache['encoder_cache']['offset'] = encoder_cache[0]
- cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
- cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
- cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
- cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
- cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
@@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
- feat, cache['decoder_cache'] = self.decoder(
+ feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
- cache=cache['decoder_cache']
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
- return feat.float(), cache
+ return feat.float(), None

View File

@@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond)
else:
- with self.lock:
- self.estimator.set_input_shape('x', (2, 80, x.size(2)))
- self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
- self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
- self.estimator.set_input_shape('t', (2,))
- self.estimator.set_input_shape('spks', (2, 80))
- self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
- # run trt engine
- assert self.estimator.execute_v2([x.contiguous().data_ptr(),
- mask.contiguous().data_ptr(),
- mu.contiguous().data_ptr(),
- t.contiguous().data_ptr(),
- spks.contiguous().data_ptr(),
- cond.contiguous().data_ptr(),
- x.data_ptr()]) is True
+ estimator, trt_engine = self.estimator.acquire_estimator()
+ estimator.set_input_shape('x', (2, 80, x.size(2)))
+ estimator.set_input_shape('mask', (2, 1, x.size(2)))
+ estimator.set_input_shape('mu', (2, 80, x.size(2)))
+ estimator.set_input_shape('t', (2,))
+ estimator.set_input_shape('spks', (2, 80))
+ estimator.set_input_shape('cond', (2, 80, x.size(2)))
+ data_ptrs = [x.contiguous().data_ptr(),
+ mask.contiguous().data_ptr(),
+ mu.contiguous().data_ptr(),
+ t.contiguous().data_ptr(),
+ spks.contiguous().data_ptr(),
+ cond.contiguous().data_ptr(),
+ x.data_ptr()]
+ for i, j in enumerate(data_ptrs):
+ estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+ # run trt engine
+ assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+ torch.cuda.current_stream().synchronize()
+ self.estimator.release_estimator(estimator)
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
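The TensorRT path now borrows an execution context from the TrtContextWrapper installed by load_trt and returns it afterwards, so trt_concurrent contexts can serve concurrent sessions without the old global lock. TrtContextWrapper is not part of this diff; a minimal context pool along these lines is one plausible shape for it (an assumption, not the repo's actual implementation):

```python
import queue
import tensorrt as trt

class TrtContextPoolSketch:
    """Hold one TensorRT engine plus a small pool of execution contexts."""

    def __init__(self, engine: trt.ICudaEngine, trt_concurrent: int = 1):
        self.engine = engine
        self.contexts = queue.Queue()
        for _ in range(trt_concurrent):
            self.contexts.put(engine.create_execution_context())

    def acquire_estimator(self):
        # Blocks until a free execution context is available; the engine is returned
        # too so callers can resolve tensor names via get_tensor_name(i).
        return self.contexts.get(), self.engine

    def release_estimator(self, context) -> None:
        self.contexts.put(context)
```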
@@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM):
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
@@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM):
shape: (batch_size, n_feats, mel_timesteps)
"""
- offset = cache.pop('offset')
- z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
- z = z[:, :, offset:]
- offset += mu.size(2)
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
- mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
- cache['offset'] = offset
- return mel, cache
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
cache_step = {k: v[step - 1] for k, v in cache.items()}
dphi_dt, cache_step = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
cache_step
)
# NOTE if smaller than flow_cache_size, means last chunk, no need to cache
if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1].float(), cache
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
if isinstance(self.estimator, torch.nn.Module):
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
else:
estimator, trt_engine = self.estimator.acquire_estimator()
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
x.data_ptr(),
cache['down_blocks_conv_cache'].data_ptr(),
down_blocks_kv_cache_out.data_ptr(),
cache['mid_blocks_conv_cache'].data_ptr(),
mid_blocks_kv_cache_out.data_ptr(),
cache['up_blocks_conv_cache'].data_ptr(),
up_blocks_kv_cache_out.data_ptr(),
cache['final_blocks_conv_cache'].data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator)
cache = (cache['down_blocks_conv_cache'],
down_blocks_kv_cache_out,
cache['mid_blocks_conv_cache'],
mid_blocks_kv_cache_out,
cache['up_blocks_conv_cache'],
up_blocks_kv_cache_out,
cache['final_blocks_conv_cache'])
return x, cache
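With the per-chunk caches gone, `CausalConditionalCFM` defers to the parent `ConditionalCFM.solve_euler`, i.e. the same classifier-free-guidance Euler loop as the removed code above but without any conv/KV cache bookkeeping. A condensed sketch reconstructed from the removed code (not a verbatim copy of the parent implementation; the default cfg rate is illustrative):

```python
import torch

def solve_euler_cfg(estimator, x, t_span, mu, mask, spks, cond, inference_cfg_rate=0.7):
    """Cache-free Euler solve with classifier-free guidance (sketch)."""
    t, dt = t_span[0].unsqueeze(dim=0), t_span[1] - t_span[0]
    # pre-allocated doubled buffers; torch.cat is avoided so the memory format
    # stays stable for the TensorRT path (see the note in the removed code)
    x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
    mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
    spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
    cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
    for step in range(1, len(t_span)):
        x_in[:] = x
        mask_in[:] = mask
        mu_in[0] = mu            # second row stays zero -> unconditional branch
        t_in[:] = t
        spks_in[0] = spks
        cond_in[0] = cond
        dphi_dt = estimator(x_in, mask_in, mu_in, t_in, spks_in, cond_in)
        dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
        dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x.float()
```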

View File

@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
return sine_merge, noise, uv return sine_merge, noise, uv
class SineGen2(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-waveform (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_threshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SineGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False):
super(SineGen2, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
def _f02sine(self, f0_values):
""" f0_values: (batchsize, length, dim)
where dim indicates fundamental tone and overtones
"""
# convert to F0 in rad. The integer part n can be ignored
# because 2 * np.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
if not self.flag_for_pulse:
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
scale_factor=1 / self.upsample_scale,
mode="linear").transpose(1, 2)
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
# voiced segment is sin(pi) or cos(0)
# This is used for pulse-train generation
# identify the last time step in unvoiced segments
uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0)
# get the instantaneous phase
tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently
for idx in range(f0_values.shape[0]):
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
# stores the accumulation of i.phase within
# each voiced segment
tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines
sines = torch.cos(i_phase * 2 * np.pi)
return sines
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal
uv = self._f02uv(f0)
# noise: for unvoiced frames, noise amplitude should be similar to sine_amp
# std = self.sine_amp / 3 -> max value ~ self.sine_amp
# for voiced regions, it is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
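The main difference from the existing SineGen is how SineGen2 accumulates phase: the per-sample phase increments are downsampled by upsample_scale, cumulatively summed at the frame rate, and then linearly interpolated back up (scaled by upsample_scale) to approximate the full-rate phase accumulation. A minimal illustration of that interpolate-cumsum-interpolate step, with illustrative values only:

```python
import math
import torch
import torch.nn.functional as F

upsample_scale = 4                         # illustrative; the model uses prod(upsample_rates) * hop_len
rad = torch.rand(1, 16, 1) * 0.01          # per-sample phase increments, (B, T, dim)

# full-rate reference: straight cumulative sum
phase_ref = torch.cumsum(rad, dim=1) * 2 * math.pi

# SineGen2: downsample increments, cumsum at the frame rate, interpolate back up
down = F.interpolate(rad.transpose(1, 2), scale_factor=1 / upsample_scale, mode="linear").transpose(1, 2)
phase = torch.cumsum(down, dim=1) * 2 * math.pi
phase = F.interpolate(phase.transpose(1, 2) * upsample_scale, scale_factor=upsample_scale, mode="linear").transpose(1, 2)

print(phase.shape, phase_ref.shape)        # both torch.Size([1, 16, 1])
```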
class SourceModuleHnNSF2(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threshold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF2, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length, 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
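As a quick sanity check of the new source module's interface, the snippet below runs SourceModuleHnNSF2 on a constant F0 contour. The parameter values are illustrative, and the import path is an assumption (adjust it to wherever HiFTGenerator lives in your checkout):

```python
import torch
from cosyvoice.hifigan.generator import SourceModuleHnNSF2  # import path assumed

# illustrative values: 24 kHz audio, F0 already upsampled to the audio rate
source = SourceModuleHnNSF2(sampling_rate=24000, upsample_scale=480, harmonic_num=8)
f0 = torch.full((1, 48000, 1), 220.0)            # 2 s of a constant 220 Hz contour
sine_merge, noise, uv = source(f0)
print(sine_merge.shape, noise.shape, uv.shape)   # torch.Size([1, 48000, 1]) for each
```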
class HiFTGenerator(nn.Module): class HiFTGenerator(nn.Module):
""" """
HiFTNet Generator: Neural Source Filter + ISTFTNet HiFTNet Generator: Neural Source Filter + ISTFTNet
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates) self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF( # NOTE CosyVoice2 (24 kHz) uses SourceModuleHnNSF2, i.e. the original SourceModuleHnNSF implementation; the 22.05 kHz model keeps the existing SourceModuleHnNSF
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
self.m_source = this_SourceModuleHnNSF(
sampling_rate=sampling_rate, sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics, harmonic_num=nb_harmonics,

View File

@@ -56,16 +56,11 @@ class Upsample1D(nn.Module):
# In this mode, first repeat interpolate, than conv with stride=1 # In this mode, first repeat interpolate, than conv with stride=1
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
if conv_cache.size(2) == 0: outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
else:
assert conv_cache.size(2) == self.stride * 2
outputs = torch.concat([conv_cache, outputs], dim=2)
conv_cache_new = outputs[:, :, -self.stride * 2:]
outputs = self.conv(outputs) outputs = self.conv(outputs)
return outputs, input_lengths * self.stride, conv_cache_new return outputs, input_lengths * self.stride
class PreLookaheadLayer(nn.Module): class PreLookaheadLayer(nn.Module):
@@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module):
kernel_size=3, stride=1, padding=0, kernel_size=3, stride=1, padding=0,
) )
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
""" """
inputs: (batch_size, seq_len, channels) inputs: (batch_size, seq_len, channels)
""" """
@@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module):
if context.size(2) == 0: if context.size(2) == 0:
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
else: else:
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
assert context.size(2) == self.pre_lookahead_len assert context.size(2) == self.pre_lookahead_len
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0) outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
outputs = F.leaky_relu(self.conv1(outputs)) outputs = F.leaky_relu(self.conv1(outputs))
# outputs # outputs
if conv2_cache.size(2) == 0: outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
else:
assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
outputs = torch.concat([conv2_cache, outputs], dim=2)
conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous() outputs = outputs.transpose(1, 2).contiguous()
# residual connection # residual connection
outputs = outputs + inputs outputs = outputs + inputs
return outputs, conv2_cache_new return outputs
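The inference-time behaviour of PreLookaheadLayer is easiest to see on shapes: during training the layer right-pads with pre_lookahead_len zero frames, while at inference it appends the same number of real frames from the next chunk, so the convolution sees an identically sized input. A small shape check with illustrative values:

```python
import torch
import torch.nn.functional as F

pre_lookahead_len = 3                      # illustrative value
outputs = torch.randn(1, 512, 20)          # (B, C, T) after the transpose in forward()
context = torch.randn(1, 512, 3)           # pre_lookahead_len future frames from the next chunk

padded_train = F.pad(outputs, (0, pre_lookahead_len), value=0.0)   # zero lookahead (training)
padded_infer = torch.concat([outputs, context], dim=2)             # real lookahead (inference)
assert padded_train.shape == padded_infer.shape == (1, 512, 23)
```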
class UpsampleConformerEncoder(torch.nn.Module): class UpsampleConformerEncoder(torch.nn.Module):
@@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
self, self,
xs: torch.Tensor, xs: torch.Tensor,
xs_lens: torch.Tensor, xs_lens: torch.Tensor,
context: torch.Tensor = torch.zeros(0, 0, 0),
decoding_chunk_size: int = 0, decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1, num_decoding_left_chunks: int = -1,
streaming: bool = False, streaming: bool = False,
@@ -280,20 +272,27 @@ class UpsampleConformerEncoder(torch.nn.Module):
checkpointing API because `__call__` attaches all the hooks of the module. checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
""" """
if hasattr(self, 'streaming'):
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
streaming = self.streaming
T = xs.size(1) T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None: if self.global_cmvn is not None:
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks) xs, pos_emb, masks = self.embed(xs, masks)
if context.size(1) != 0:
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
mask_pad = masks # (B, 1, T/subsample_rate) mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1) chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
# lookahead + conformer encoder # lookahead + conformer encoder
xs, _ = self.pre_lookahead_layer(xs) xs = self.pre_lookahead_layer(xs, context=context)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
# upsample + conformer encoder # upsample + conformer encoder
xs = xs.transpose(1, 2).contiguous() xs = xs.transpose(1, 2).contiguous()
xs, xs_lens, _ = self.up_layer(xs, xs_lens) xs, xs_lens = self.up_layer(xs, xs_lens)
xs = xs.transpose(1, 2).contiguous() xs = xs.transpose(1, 2).contiguous()
T = xs.size(1) T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
@@ -322,100 +321,3 @@ class UpsampleConformerEncoder(torch.nn.Module):
for layer in self.up_encoders: for layer in self.up_encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
return xs return xs
@torch.jit.export
def forward_chunk(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
offset: int = 0,
context: torch.Tensor = torch.zeros(0, 0, 0),
pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
upsample_offset: int = 0,
upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
NOTE(xcsong):
We pass the `__call__` method of the modules instead of `forward` to the
checkpointing API because `__call__` attaches all the hooks of the module.
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
offset += xs.size(1)
tmp_masks = torch.ones(1,
context.size(1),
device=context.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if context.size(1) != 0:
context, _, _ = self.embed(context, tmp_masks, offset)
# lookahead + conformer encoder
xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
# NOTE in cache mode we do not need to call add_optional_chunk_mask
chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
encoders_kv_cache_list = []
for index, layer in enumerate(self.encoders):
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
encoders_kv_cache_list.append(encoders_kv_cache_new)
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
# upsample
xs = xs.transpose(1, 2).contiguous()
xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
xs = xs.transpose(1, 2).contiguous()
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
upsample_offset += xs.size(1)
# conformer encoder
chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
upsample_kv_cache_list = []
for index, layer in enumerate(self.up_encoders):
xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
upsample_kv_cache_list.append(upsample_kv_cache_new)
upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)

View File

@@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
network = builder.create_network(network_flags) network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger) parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config() config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 2GB
if fp16: if fp16:
config.set_flag(trt.BuilderFlag.FP16) config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile() profile = builder.create_optimization_profile()

View File

@@ -86,7 +86,7 @@ def subsequent_mask(
return mask return mask
def subsequent_chunk_mask( def subsequent_chunk_mask_deprecated(
size: int, size: int,
chunk_size: int, chunk_size: int,
num_left_chunks: int = -1, num_left_chunks: int = -1,
@@ -124,6 +124,40 @@ def subsequent_chunk_mask(
return ret return ret
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
pos_idx = torch.arange(size, device=device)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
return ret
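The rewritten mask can be checked directly against the docstring example; the snippet below mirrors the function body inline rather than importing it:

```python
import torch

size, chunk_size = 4, 2
pos_idx = torch.arange(size)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
mask = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
print(mask.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```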
def add_optional_chunk_mask(xs: torch.Tensor, def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor, masks: torch.Tensor,
use_dynamic_chunk: bool, use_dynamic_chunk: bool,
@@ -196,9 +230,6 @@ def add_optional_chunk_mask(xs: torch.Tensor,
else: else:
chunk_masks = masks chunk_masks = masks
assert chunk_masks.dtype == torch.bool assert chunk_masks.dtype == torch.bool
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
return chunk_masks return chunk_masks

View File

@@ -1,37 +0,0 @@
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio # type: ignore
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# save zero_shot spk for future usage
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice.save_spkinfo()
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
yield '收到好友从远方寄来的生日礼物,'
yield '那份意外的惊喜与深深的祝福'
yield '让我心中充满了甜蜜的快乐,'
yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)