mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
remove flow_cache
This commit is contained in:
@@ -126,7 +126,7 @@ import torchaudio
|
|||||||
|
|
||||||
**CosyVoice2 Usage**
|
**CosyVoice2 Usage**
|
||||||
```python
|
```python
|
||||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
|
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
|
||||||
|
|
||||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||||
# zero_shot usage
|
# zero_shot usage
|
||||||
|
|||||||
@@ -61,8 +61,7 @@ def main():
|
|||||||
model = CosyVoice(args.model_dir)
|
model = CosyVoice(args.model_dir)
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
# NOTE set use_flow_cache=True when export jit for cache inference
|
model = CosyVoice2(args.model_dir)
|
||||||
model = CosyVoice2(args.model_dir, use_flow_cache=True)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise TypeError('no valid model_type!')
|
raise TypeError('no valid model_type!')
|
||||||
|
|
||||||
@@ -93,9 +92,9 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# 3. export flow encoder
|
# 3. export flow encoder
|
||||||
flow_encoder = model.model.flow.encoder
|
flow_encoder = model.model.flow.encoder
|
||||||
script = get_optimized_script(flow_encoder, ['forward_chunk'])
|
script = get_optimized_script(flow_encoder)
|
||||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||||
script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
|
script = get_optimized_script(flow_encoder.half())
|
||||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||||
logging.info('successfully export flow_encoder')
|
logging.info('successfully export flow_encoder')
|
||||||
|
|
||||||
|
|||||||
@@ -62,135 +62,58 @@ def main():
|
|||||||
model = CosyVoice(args.model_dir)
|
model = CosyVoice(args.model_dir)
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
# NOTE set use_flow_cache=True when export jit for cache inference
|
model = CosyVoice2(args.model_dir)
|
||||||
model = CosyVoice2(args.model_dir, use_flow_cache=True)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise TypeError('no valid model_type!')
|
raise TypeError('no valid model_type!')
|
||||||
|
|
||||||
if not isinstance(model, CosyVoice2):
|
# 1. export flow decoder estimator
|
||||||
# 1. export flow decoder estimator
|
estimator = model.model.flow.decoder.estimator
|
||||||
estimator = model.model.flow.decoder.estimator
|
estimator.eval()
|
||||||
estimator.eval()
|
|
||||||
|
|
||||||
device = model.model.device
|
device = model.model.device
|
||||||
batch_size, seq_len = 2, 256
|
batch_size, seq_len = 2, 256
|
||||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
estimator,
|
estimator,
|
||||||
(x, mask, mu, t, spks, cond),
|
(x, mask, mu, t, spks, cond),
|
||||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||||
export_params=True,
|
export_params=True,
|
||||||
opset_version=18,
|
opset_version=18,
|
||||||
do_constant_folding=True,
|
do_constant_folding=True,
|
||||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||||
output_names=['estimator_out'],
|
output_names=['estimator_out'],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
'x': {2: 'seq_len'},
|
'x': {2: 'seq_len'},
|
||||||
'mask': {2: 'seq_len'},
|
'mask': {2: 'seq_len'},
|
||||||
'mu': {2: 'seq_len'},
|
'mu': {2: 'seq_len'},
|
||||||
'cond': {2: 'seq_len'},
|
'cond': {2: 'seq_len'},
|
||||||
'estimator_out': {2: 'seq_len'},
|
'estimator_out': {2: 'seq_len'},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. test computation consistency
|
# 2. test computation consistency
|
||||||
option = onnxruntime.SessionOptions()
|
option = onnxruntime.SessionOptions()
|
||||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
option.intra_op_num_threads = 1
|
option.intra_op_num_threads = 1
|
||||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||||
sess_options=option, providers=providers)
|
sess_options=option, providers=providers)
|
||||||
|
|
||||||
for _ in tqdm(range(10)):
|
for _ in tqdm(range(10)):
|
||||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||||
ort_inputs = {
|
ort_inputs = {
|
||||||
'x': x.cpu().numpy(),
|
'x': x.cpu().numpy(),
|
||||||
'mask': mask.cpu().numpy(),
|
'mask': mask.cpu().numpy(),
|
||||||
'mu': mu.cpu().numpy(),
|
'mu': mu.cpu().numpy(),
|
||||||
't': t.cpu().numpy(),
|
't': t.cpu().numpy(),
|
||||||
'spks': spks.cpu().numpy(),
|
'spks': spks.cpu().numpy(),
|
||||||
'cond': cond.cpu().numpy()
|
'cond': cond.cpu().numpy()
|
||||||
}
|
}
|
||||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||||
logging.info('successfully export estimator')
|
logging.info('successfully export estimator')
|
||||||
else:
|
|
||||||
# 1. export flow decoder estimator
|
|
||||||
estimator = model.model.flow.decoder.estimator
|
|
||||||
estimator.forward = estimator.forward_chunk
|
|
||||||
estimator.eval()
|
|
||||||
|
|
||||||
device = model.model.device
|
|
||||||
batch_size, seq_len = 2, 256
|
|
||||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
|
||||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
|
||||||
cache = model.model.init_flow_cache()['decoder_cache']
|
|
||||||
cache.pop('offset')
|
|
||||||
cache = {k: v[0] for k, v in cache.items()}
|
|
||||||
torch.onnx.export(
|
|
||||||
estimator,
|
|
||||||
(x, mask, mu, t, spks, cond,
|
|
||||||
cache['down_blocks_conv_cache'],
|
|
||||||
cache['down_blocks_kv_cache'],
|
|
||||||
cache['mid_blocks_conv_cache'],
|
|
||||||
cache['mid_blocks_kv_cache'],
|
|
||||||
cache['up_blocks_conv_cache'],
|
|
||||||
cache['up_blocks_kv_cache'],
|
|
||||||
cache['final_blocks_conv_cache']),
|
|
||||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
|
||||||
export_params=True,
|
|
||||||
opset_version=18,
|
|
||||||
do_constant_folding=True,
|
|
||||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
|
|
||||||
'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
|
|
||||||
output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
|
|
||||||
'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
|
|
||||||
dynamic_axes={
|
|
||||||
'x': {2: 'seq_len'},
|
|
||||||
'mask': {2: 'seq_len'},
|
|
||||||
'mu': {2: 'seq_len'},
|
|
||||||
'cond': {2: 'seq_len'},
|
|
||||||
'down_blocks_kv_cache': {3: 'cache_in_len'},
|
|
||||||
'mid_blocks_kv_cache': {3: 'cache_in_len'},
|
|
||||||
'up_blocks_kv_cache': {3: 'cache_in_len'},
|
|
||||||
'estimator_out': {2: 'seq_len'},
|
|
||||||
'down_blocks_kv_cache_out': {3: 'cache_out_len'},
|
|
||||||
'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
|
|
||||||
'up_blocks_kv_cache_out': {3: 'cache_out_len'},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. test computation consistency
|
|
||||||
option = onnxruntime.SessionOptions()
|
|
||||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
||||||
option.intra_op_num_threads = 1
|
|
||||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
|
||||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
|
||||||
sess_options=option, providers=providers)
|
|
||||||
|
|
||||||
for iter in tqdm(range(10)):
|
|
||||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
|
||||||
cache = model.model.init_flow_cache()['decoder_cache']
|
|
||||||
cache.pop('offset')
|
|
||||||
cache = {k: v[0] for k, v in cache.items()}
|
|
||||||
output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
|
|
||||||
ort_inputs = {
|
|
||||||
'x': x.cpu().numpy(),
|
|
||||||
'mask': mask.cpu().numpy(),
|
|
||||||
'mu': mu.cpu().numpy(),
|
|
||||||
't': t.cpu().numpy(),
|
|
||||||
'spks': spks.cpu().numpy(),
|
|
||||||
'cond': cond.cpu().numpy(),
|
|
||||||
}
|
|
||||||
output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
|
|
||||||
if iter == 0:
|
|
||||||
# NOTE why can not pass first iteration check?
|
|
||||||
continue
|
|
||||||
for i, j in zip(output_pytorch, output_onnx):
|
|
||||||
torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
|
|
||||||
logging.info('successfully export estimator')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class CosyVoice:
|
|||||||
|
|
||||||
class CosyVoice2(CosyVoice):
|
class CosyVoice2(CosyVoice):
|
||||||
|
|
||||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1):
|
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||||
self.instruct = True if '-Instruct' in model_dir else False
|
self.instruct = True if '-Instruct' in model_dir else False
|
||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
@@ -162,9 +162,9 @@ class CosyVoice2(CosyVoice):
|
|||||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||||
load_jit, load_trt, fp16 = False, False, False
|
load_jit, load_trt, fp16 = False, False, False
|
||||||
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent)
|
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
|
||||||
self.model.load('{}/llm.pt'.format(model_dir),
|
self.model.load('{}/llm.pt'.format(model_dir),
|
||||||
'{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
|
'{}/flow.pt'.format(model_dir),
|
||||||
'{}/hift.pt'.format(model_dir))
|
'{}/hift.pt'.format(model_dir))
|
||||||
if load_jit:
|
if load_jit:
|
||||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||||
|
|||||||
@@ -33,12 +33,14 @@ class CosyVoiceModel:
|
|||||||
llm: torch.nn.Module,
|
llm: torch.nn.Module,
|
||||||
flow: torch.nn.Module,
|
flow: torch.nn.Module,
|
||||||
hift: torch.nn.Module,
|
hift: torch.nn.Module,
|
||||||
fp16: bool = False):
|
fp16: bool = False,
|
||||||
|
trt_concurrent: int = 1):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.hift = hift
|
self.hift = hift
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
|
self.trt_concurrent = trt_concurrent
|
||||||
if self.fp16 is True:
|
if self.fp16 is True:
|
||||||
self.llm.half()
|
self.llm.half()
|
||||||
self.flow.half()
|
self.flow.half()
|
||||||
@@ -85,23 +87,18 @@ class CosyVoiceModel:
|
|||||||
|
|
||||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
|
||||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||||
if not os.path.exists(flow_decoder_estimator_model):
|
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||||
if os.path.getsize(flow_decoder_estimator_model) == 0:
|
|
||||||
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
|
|
||||||
del self.flow.decoder.estimator
|
del self.flow.decoder.estimator
|
||||||
import tensorrt as trt
|
import tensorrt as trt
|
||||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||||
if isinstance(self, CosyVoice2Model):
|
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
|
||||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
|
|
||||||
else:
|
|
||||||
self.flow.decoder.estimator = estimator_engine.create_execution_context()
|
|
||||||
|
|
||||||
def get_trt_kwargs(self):
|
def get_trt_kwargs(self):
|
||||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||||
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
|
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||||
input_names = ["x", "mask", "mu", "cond"]
|
input_names = ["x", "mask", "mu", "cond"]
|
||||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||||
@@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
flow: torch.nn.Module,
|
flow: torch.nn.Module,
|
||||||
hift: torch.nn.Module,
|
hift: torch.nn.Module,
|
||||||
fp16: bool = False,
|
fp16: bool = False,
|
||||||
use_flow_cache: bool = False,
|
|
||||||
trt_concurrent: int = 1):
|
trt_concurrent: int = 1):
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
|
# NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
|
||||||
|
self.flow.encoder.streaming = True
|
||||||
|
self.flow.decoder.estimator.streaming = True
|
||||||
self.hift = hift
|
self.hift = hift
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
self.use_flow_cache = use_flow_cache
|
|
||||||
self.trt_concurrent = trt_concurrent
|
self.trt_concurrent = trt_concurrent
|
||||||
if self.fp16 is True:
|
if self.fp16 is True:
|
||||||
self.llm.half()
|
self.llm.half()
|
||||||
self.flow.half()
|
self.flow.half()
|
||||||
# stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
|
# NOTE must matching training static_chunk_size
|
||||||
self.token_hop_len = 25
|
self.token_hop_len = 25
|
||||||
self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
|
|
||||||
# hift cache
|
# hift cache
|
||||||
self.mel_cache_len = 8
|
self.mel_cache_len = 8
|
||||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||||
@@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
# dict used to store session related variable
|
# dict used to store session related variable
|
||||||
self.tts_speech_token_dict = {}
|
self.tts_speech_token_dict = {}
|
||||||
self.llm_end_dict = {}
|
self.llm_end_dict = {}
|
||||||
self.flow_cache_dict = {}
|
|
||||||
self.hift_cache_dict = {}
|
self.hift_cache_dict = {}
|
||||||
self.trt_context_dict = {}
|
self.trt_context_dict = {}
|
||||||
|
|
||||||
def init_flow_cache(self):
|
|
||||||
encoder_cache = {'offset': 0,
|
|
||||||
'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
|
|
||||||
'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
|
|
||||||
'upsample_offset': 0,
|
|
||||||
'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
|
|
||||||
'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
|
|
||||||
decoder_cache = {'offset': 0,
|
|
||||||
'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
|
|
||||||
'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
|
||||||
'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
|
|
||||||
'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
|
||||||
'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
|
|
||||||
'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
|
|
||||||
'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
|
|
||||||
if self.fp16 is True:
|
|
||||||
for cache in [encoder_cache, decoder_cache]:
|
|
||||||
for k, v in cache.items():
|
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
cache[k] = v.half()
|
|
||||||
cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
|
|
||||||
return cache
|
|
||||||
|
|
||||||
def load_jit(self, flow_encoder_model):
|
def load_jit(self, flow_encoder_model):
|
||||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||||
self.flow.encoder = flow_encoder
|
self.flow.encoder = flow_encoder
|
||||||
|
|
||||||
def get_trt_kwargs(self):
|
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
|
||||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
|
|
||||||
opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
|
|
||||||
max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
|
|
||||||
input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
|
|
||||||
assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
|
|
||||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
|
||||||
|
|
||||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
|
||||||
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
|
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
|
||||||
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
|
tts_mel, _ = self.flow.inference(token=token.to(self.device),
|
||||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
prompt_token=prompt_token.to(self.device),
|
prompt_token=prompt_token.to(self.device),
|
||||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
prompt_feat=prompt_feat.to(self.device),
|
prompt_feat=prompt_feat.to(self.device),
|
||||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||||
embedding=embedding.to(self.device),
|
embedding=embedding.to(self.device),
|
||||||
cache=self.flow_cache_dict[uuid],
|
finalize=finalize)
|
||||||
finalize=finalize)
|
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||||
# append hift cache
|
# append hift cache
|
||||||
if self.hift_cache_dict[uuid] is not None:
|
if self.hift_cache_dict[uuid] is not None:
|
||||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||||
@@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
with self.lock:
|
with self.lock:
|
||||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||||
self.hift_cache_dict[this_uuid] = None
|
self.hift_cache_dict[this_uuid] = None
|
||||||
self.flow_cache_dict[this_uuid] = self.init_flow_cache()
|
|
||||||
self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
|
self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
|
||||||
if source_speech_token.shape[1] == 0:
|
if source_speech_token.shape[1] == 0:
|
||||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||||
@@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||||
p.start()
|
p.start()
|
||||||
if stream is True:
|
if stream is True:
|
||||||
assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
|
token_offset = 0
|
||||||
# NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
|
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
|
||||||
flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
|
|
||||||
prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
|
|
||||||
while True:
|
while True:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
|
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
|
||||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
|
||||||
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
prompt_token=flow_prompt_speech_token,
|
prompt_token=flow_prompt_speech_token,
|
||||||
prompt_feat=prompt_speech_feat,
|
prompt_feat=prompt_speech_feat,
|
||||||
embedding=flow_embedding,
|
embedding=flow_embedding,
|
||||||
|
token_offset=token_offset,
|
||||||
uuid=this_uuid,
|
uuid=this_uuid,
|
||||||
finalize=False)
|
finalize=False)
|
||||||
# NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
|
token_offset += this_token_hop_len
|
||||||
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
|
|
||||||
prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
|
|
||||||
yield {'tts_speech': this_tts_speech.cpu()}
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
with self.lock:
|
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
|
||||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
|
|
||||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
|
|
||||||
break
|
break
|
||||||
p.join()
|
p.join()
|
||||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||||
@@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
prompt_token=flow_prompt_speech_token,
|
prompt_token=flow_prompt_speech_token,
|
||||||
prompt_feat=prompt_speech_feat,
|
prompt_feat=prompt_speech_feat,
|
||||||
embedding=flow_embedding,
|
embedding=flow_embedding,
|
||||||
|
token_offset=token_offset,
|
||||||
uuid=this_uuid,
|
uuid=this_uuid,
|
||||||
finalize=True)
|
finalize=True)
|
||||||
yield {'tts_speech': this_tts_speech.cpu()}
|
yield {'tts_speech': this_tts_speech.cpu()}
|
||||||
else:
|
else:
|
||||||
# deal with all tokens
|
# deal with all tokens
|
||||||
assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
|
|
||||||
p.join()
|
p.join()
|
||||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||||
prompt_token=flow_prompt_speech_token,
|
prompt_token=flow_prompt_speech_token,
|
||||||
prompt_feat=prompt_speech_feat,
|
prompt_feat=prompt_speech_feat,
|
||||||
embedding=flow_embedding,
|
embedding=flow_embedding,
|
||||||
|
token_offset=0,
|
||||||
uuid=this_uuid,
|
uuid=this_uuid,
|
||||||
finalize=True,
|
finalize=True,
|
||||||
speed=speed)
|
speed=speed)
|
||||||
@@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
|
|||||||
self.tts_speech_token_dict.pop(this_uuid)
|
self.tts_speech_token_dict.pop(this_uuid)
|
||||||
self.llm_end_dict.pop(this_uuid)
|
self.llm_end_dict.pop(this_uuid)
|
||||||
self.hift_cache_dict.pop(this_uuid)
|
self.hift_cache_dict.pop(this_uuid)
|
||||||
self.flow_cache_dict.pop(this_uuid)
|
|
||||||
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
|
self.trt_context_pool.put(self.trt_context_dict[this_uuid])
|
||||||
self.trt_context_dict.pop(this_uuid)
|
self.trt_context_dict.pop(this_uuid)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
|||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
feat_extractor,
|
feat_extractor,
|
||||||
token_mel_ratio=2,
|
token_mel_ratio=0,
|
||||||
mode='train'):
|
mode='train'):
|
||||||
""" Extract fbank
|
""" Extract fbank
|
||||||
|
|
||||||
@@ -176,12 +176,11 @@ def compute_fbank(data,
|
|||||||
assert 'text_token' in sample
|
assert 'text_token' in sample
|
||||||
waveform = sample['speech']
|
waveform = sample['speech']
|
||||||
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
||||||
|
if token_mel_ratio != 0:
|
||||||
# trim to align speech_token and speech_feat
|
# trim to align speech_token and speech_feat
|
||||||
token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
|
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
|
||||||
feat = feat[:token_mel_ratio * token_len]
|
feat = feat[:token_mel_ratio * token_len]
|
||||||
sample["speech_token"] = sample["speech_token"][:token_len]
|
sample["speech_token"] = sample["speech_token"][:token_len]
|
||||||
|
|
||||||
sample['speech_feat'] = feat
|
sample['speech_feat'] = feat
|
||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|||||||
@@ -11,16 +11,15 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Tuple, Optional, Dict, Any
|
from typing import Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import pack, rearrange, repeat
|
from einops import pack, rearrange, repeat
|
||||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
|
|
||||||
from cosyvoice.utils.common import mask_to_bias
|
from cosyvoice.utils.common import mask_to_bias
|
||||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||||
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
||||||
from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
|
from matcha.models.components.transformer import BasicTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
class Transpose(torch.nn.Module):
|
class Transpose(torch.nn.Module):
|
||||||
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
|
|||||||
self.dim0 = dim0
|
self.dim0 = dim0
|
||||||
self.dim1 = dim1
|
self.dim1 = dim1
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x = torch.transpose(x, self.dim0, self.dim1)
|
x = torch.transpose(x, self.dim0, self.dim1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
|
|||||||
assert stride == 1
|
assert stride == 1
|
||||||
self.causal_padding = kernel_size - 1
|
self.causal_padding = kernel_size - 1
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if cache.size(2) == 0:
|
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
|
||||||
else:
|
|
||||||
assert cache.size(2) == self.causal_padding
|
|
||||||
x = torch.concat([cache, x], dim=2)
|
|
||||||
cache = x[:, :, -self.causal_padding:]
|
|
||||||
x = super(CausalConv1d, self).forward(x)
|
x = super(CausalConv1d, self).forward(x)
|
||||||
return x, cache
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CausalBlock1D(Block1D):
|
class CausalBlock1D(Block1D):
|
||||||
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
|
|||||||
nn.Mish(),
|
nn.Mish(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
output, cache = self.block[0](x * mask, cache)
|
output = self.block(x * mask)
|
||||||
for i in range(1, len(self.block)):
|
return output * mask
|
||||||
output = self.block[i](output)
|
|
||||||
return output * mask, cache
|
|
||||||
|
|
||||||
|
|
||||||
class CausalResnetBlock1D(ResnetBlock1D):
|
class CausalResnetBlock1D(ResnetBlock1D):
|
||||||
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
|
|||||||
self.block1 = CausalBlock1D(dim, dim_out)
|
self.block1 = CausalBlock1D(dim, dim_out)
|
||||||
self.block2 = CausalBlock1D(dim_out, dim_out)
|
self.block2 = CausalBlock1D(dim_out, dim_out)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
|
|
||||||
block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
h, block1_cache = self.block1(x, mask, block1_cache)
|
|
||||||
h += self.mlp(time_emb).unsqueeze(-1)
|
|
||||||
h, block2_cache = self.block2(h, mask, block2_cache)
|
|
||||||
output = h + self.res_conv(x * mask)
|
|
||||||
return output, block1_cache, block2_cache
|
|
||||||
|
|
||||||
|
|
||||||
class CausalAttnProcessor2_0(AttnProcessor2_0):
|
|
||||||
r"""
|
|
||||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(CausalAttnProcessor2_0, self).__init__()
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
attn: Attention,
|
|
||||||
hidden_states: torch.FloatTensor,
|
|
||||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
temb: Optional[torch.FloatTensor] = None,
|
|
||||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, torch.Tensor]:
|
|
||||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
|
||||||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
|
|
||||||
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
|
||||||
deprecate("scale", "1.0.0", deprecation_message)
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
# NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
|
|
||||||
# scaled_dot_product_attention expects attention_mask shape to be
|
|
||||||
# (batch, heads, source_length, target_length)
|
|
||||||
attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key_cache = attn.to_k(encoder_hidden_states)
|
|
||||||
value_cache = attn.to_v(encoder_hidden_states)
|
|
||||||
# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
|
|
||||||
if cache.size(0) != 0:
|
|
||||||
key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
|
||||||
value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
|
||||||
else:
|
|
||||||
key, value = key_cache, value_cache
|
|
||||||
cache = torch.stack([key_cache, value_cache], dim=3)
|
|
||||||
|
|
||||||
inner_dim = key.shape[-1]
|
|
||||||
head_dim = inner_dim // attn.heads
|
|
||||||
|
|
||||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states, cache
|
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
|
||||||
class CausalAttention(Attention):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
query_dim: int,
|
|
||||||
cross_attention_dim: Optional[int] = None,
|
|
||||||
heads: int = 8,
|
|
||||||
dim_head: int = 64,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
bias: bool = False,
|
|
||||||
upcast_attention: bool = False,
|
|
||||||
upcast_softmax: bool = False,
|
|
||||||
cross_attention_norm: Optional[str] = None,
|
|
||||||
cross_attention_norm_num_groups: int = 32,
|
|
||||||
qk_norm: Optional[str] = None,
|
|
||||||
added_kv_proj_dim: Optional[int] = None,
|
|
||||||
norm_num_groups: Optional[int] = None,
|
|
||||||
spatial_norm_dim: Optional[int] = None,
|
|
||||||
out_bias: bool = True,
|
|
||||||
scale_qk: bool = True,
|
|
||||||
only_cross_attention: bool = False,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
rescale_output_factor: float = 1.0,
|
|
||||||
residual_connection: bool = False,
|
|
||||||
_from_deprecated_attn_block: bool = False,
|
|
||||||
processor: Optional["AttnProcessor2_0"] = None,
|
|
||||||
out_dim: int = None,
|
|
||||||
):
|
|
||||||
super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
|
|
||||||
cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
|
|
||||||
spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
|
|
||||||
_from_deprecated_attn_block, processor, out_dim)
|
|
||||||
processor = CausalAttnProcessor2_0()
|
|
||||||
self.set_processor(processor)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.FloatTensor,
|
|
||||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
**cross_attention_kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
r"""
|
|
||||||
The forward method of the `Attention` class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_states (`torch.Tensor`):
|
|
||||||
The hidden states of the query.
|
|
||||||
encoder_hidden_states (`torch.Tensor`, *optional*):
|
|
||||||
The hidden states of the encoder.
|
|
||||||
attention_mask (`torch.Tensor`, *optional*):
|
|
||||||
The attention mask to use. If `None`, no mask is applied.
|
|
||||||
**cross_attention_kwargs:
|
|
||||||
Additional keyword arguments to pass along to the cross attention.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`torch.Tensor`: The output of the attention layer.
|
|
||||||
"""
|
|
||||||
# The `Attention` class can call different attention processors / attention functions
|
|
||||||
# here we simply pass along all tensors to the selected processor class
|
|
||||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
|
||||||
|
|
||||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
|
||||||
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
|
|
||||||
if len(unused_kwargs) > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
|
||||||
)
|
|
||||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
|
||||||
|
|
||||||
return self.processor(
|
|
||||||
self,
|
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
cache=cache,
|
|
||||||
**cross_attention_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
|
||||||
class CausalBasicTransformerBlock(BasicTransformerBlock):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
num_attention_heads: int,
|
|
||||||
attention_head_dim: int,
|
|
||||||
dropout=0.0,
|
|
||||||
cross_attention_dim: Optional[int] = None,
|
|
||||||
activation_fn: str = "geglu",
|
|
||||||
num_embeds_ada_norm: Optional[int] = None,
|
|
||||||
attention_bias: bool = False,
|
|
||||||
only_cross_attention: bool = False,
|
|
||||||
double_self_attention: bool = False,
|
|
||||||
upcast_attention: bool = False,
|
|
||||||
norm_elementwise_affine: bool = True,
|
|
||||||
norm_type: str = "layer_norm",
|
|
||||||
final_dropout: bool = False,
|
|
||||||
):
|
|
||||||
super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
|
|
||||||
cross_attention_dim, activation_fn, num_embeds_ada_norm,
|
|
||||||
attention_bias, only_cross_attention, double_self_attention,
|
|
||||||
upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
|
|
||||||
self.attn1 = CausalAttention(
|
|
||||||
query_dim=dim,
|
|
||||||
heads=num_attention_heads,
|
|
||||||
dim_head=attention_head_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
bias=attention_bias,
|
|
||||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
|
||||||
upcast_attention=upcast_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.FloatTensor,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
|
||||||
timestep: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_kwargs: Dict[str, Any] = None,
|
|
||||||
class_labels: Optional[torch.LongTensor] = None,
|
|
||||||
cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
|
||||||
# 1. Self-Attention
|
|
||||||
if self.use_ada_layer_norm:
|
|
||||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
|
||||||
elif self.use_ada_layer_norm_zero:
|
|
||||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
|
||||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
|
||||||
|
|
||||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
|
||||||
|
|
||||||
attn_output, cache = self.attn1(
|
|
||||||
norm_hidden_states,
|
|
||||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
|
||||||
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
|
|
||||||
cache=cache,
|
|
||||||
**cross_attention_kwargs,
|
|
||||||
)
|
|
||||||
if self.use_ada_layer_norm_zero:
|
|
||||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
|
||||||
hidden_states = attn_output + hidden_states
|
|
||||||
|
|
||||||
# 2. Cross-Attention
|
|
||||||
if self.attn2 is not None:
|
|
||||||
norm_hidden_states = (
|
|
||||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = self.attn2(
|
|
||||||
norm_hidden_states,
|
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
attention_mask=encoder_attention_mask,
|
|
||||||
**cross_attention_kwargs,
|
|
||||||
)
|
|
||||||
hidden_states = attn_output + hidden_states
|
|
||||||
|
|
||||||
# 3. Feed-forward
|
|
||||||
norm_hidden_states = self.norm3(hidden_states)
|
|
||||||
|
|
||||||
if self.use_ada_layer_norm_zero:
|
|
||||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
|
||||||
|
|
||||||
if self._chunk_size is not None:
|
|
||||||
# "feed_forward_chunk_size" can be used to save memory
|
|
||||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
|
||||||
raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
|
|
||||||
{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
|
|
||||||
|
|
||||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
|
||||||
ff_output = torch.cat(
|
|
||||||
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
|
||||||
dim=self._chunk_dim,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ff_output = self.ff(norm_hidden_states)
|
|
||||||
|
|
||||||
if self.use_ada_layer_norm_zero:
|
|
||||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
|
||||||
|
|
||||||
hidden_states = ff_output + hidden_states
|
|
||||||
|
|
||||||
return hidden_states, cache
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalDecoder(nn.Module):
|
class ConditionalDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||||
transformer_blocks = nn.ModuleList(
|
transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
CausalBasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
dim=output_channel,
|
dim=output_channel,
|
||||||
num_attention_heads=num_heads,
|
num_attention_heads=num_heads,
|
||||||
attention_head_dim=attention_head_dim,
|
attention_head_dim=attention_head_dim,
|
||||||
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
|
|
||||||
transformer_blocks = nn.ModuleList(
|
transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
CausalBasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
dim=output_channel,
|
dim=output_channel,
|
||||||
num_attention_heads=num_heads,
|
num_attention_heads=num_heads,
|
||||||
attention_head_dim=attention_head_dim,
|
attention_head_dim=attention_head_dim,
|
||||||
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
)
|
)
|
||||||
transformer_blocks = nn.ModuleList(
|
transformer_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
CausalBasicTransformerBlock(
|
BasicTransformerBlock(
|
||||||
dim=output_channel,
|
dim=output_channel,
|
||||||
num_attention_heads=num_heads,
|
num_attention_heads=num_heads,
|
||||||
attention_head_dim=attention_head_dim,
|
attention_head_dim=attention_head_dim,
|
||||||
@@ -724,6 +419,9 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
Returns:
|
Returns:
|
||||||
_type_: _description_
|
_type_: _description_
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self, 'streaming'):
|
||||||
|
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
|
||||||
|
streaming = self.streaming
|
||||||
|
|
||||||
t = self.time_embeddings(t).to(t.dtype)
|
t = self.time_embeddings(t).to(t.dtype)
|
||||||
t = self.time_mlp(t)
|
t = self.time_mlp(t)
|
||||||
@@ -740,36 +438,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
masks = [mask]
|
masks = [mask]
|
||||||
for resnet, transformer_blocks, downsample in self.down_blocks:
|
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||||
mask_down = masks[-1]
|
mask_down = masks[-1]
|
||||||
x, _, _ = resnet(x, mask_down, t)
|
x = resnet(x, mask_down, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
if streaming is True:
|
if streaming is True:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
else:
|
else:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x, _ = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
attention_mask=attn_mask,
|
attention_mask=attn_mask,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
)
|
)
|
||||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||||
hiddens.append(x) # Save hidden states for skip connections
|
hiddens.append(x) # Save hidden states for skip connections
|
||||||
x, _ = downsample(x * mask_down)
|
x = downsample(x * mask_down)
|
||||||
masks.append(mask_down[:, :, ::2])
|
masks.append(mask_down[:, :, ::2])
|
||||||
masks = masks[:-1]
|
masks = masks[:-1]
|
||||||
mask_mid = masks[-1]
|
mask_mid = masks[-1]
|
||||||
|
|
||||||
for resnet, transformer_blocks in self.mid_blocks:
|
for resnet, transformer_blocks in self.mid_blocks:
|
||||||
x, _, _ = resnet(x, mask_mid, t)
|
x = resnet(x, mask_mid, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
if streaming is True:
|
if streaming is True:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
else:
|
else:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x, _ = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
attention_mask=attn_mask,
|
attention_mask=attn_mask,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
@@ -780,124 +478,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|||||||
mask_up = masks.pop()
|
mask_up = masks.pop()
|
||||||
skip = hiddens.pop()
|
skip = hiddens.pop()
|
||||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||||
x, _, _ = resnet(x, mask_up, t)
|
x = resnet(x, mask_up, t)
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||||
if streaming is True:
|
if streaming is True:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
|
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||||
else:
|
else:
|
||||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||||
for transformer_block in transformer_blocks:
|
for transformer_block in transformer_blocks:
|
||||||
x, _ = transformer_block(
|
x = transformer_block(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
attention_mask=attn_mask,
|
attention_mask=attn_mask,
|
||||||
timestep=t,
|
timestep=t,
|
||||||
)
|
)
|
||||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||||
x, _ = upsample(x * mask_up)
|
x = upsample(x * mask_up)
|
||||||
x, _ = self.final_block(x, mask_up)
|
x = self.final_block(x, mask_up)
|
||||||
output = self.final_proj(x * mask_up)
|
output = self.final_proj(x * mask_up)
|
||||||
return output * mask
|
return output * mask
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
|
|
||||||
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
|
||||||
mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
|
||||||
up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
||||||
up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
|
||||||
final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""Forward pass of the UNet1DConditional model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
|
||||||
mask (_type_): shape (batch_size, 1, time)
|
|
||||||
t (_type_): shape (batch_size)
|
|
||||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
|
||||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: _description_
|
|
||||||
ValueError: _description_
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
_type_: _description_
|
|
||||||
"""
|
|
||||||
|
|
||||||
t = self.time_embeddings(t).to(t.dtype)
|
|
||||||
t = self.time_mlp(t)
|
|
||||||
|
|
||||||
x = pack([x, mu], "b * t")[0]
|
|
||||||
|
|
||||||
if spks is not None:
|
|
||||||
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
|
||||||
x = pack([x, spks], "b * t")[0]
|
|
||||||
if cond is not None:
|
|
||||||
x = pack([x, cond], "b * t")[0]
|
|
||||||
|
|
||||||
hiddens = []
|
|
||||||
masks = [mask]
|
|
||||||
|
|
||||||
down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
|
|
||||||
mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
|
|
||||||
up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
|
|
||||||
for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
|
|
||||||
mask_down = masks[-1]
|
|
||||||
x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
|
|
||||||
resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
|
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
||||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
|
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
|
||||||
for i, transformer_block in enumerate(transformer_blocks):
|
|
||||||
x, down_blocks_kv_cache_new[index, i] = transformer_block(
|
|
||||||
hidden_states=x,
|
|
||||||
attention_mask=attn_mask,
|
|
||||||
timestep=t,
|
|
||||||
cache=down_blocks_kv_cache[index, i],
|
|
||||||
)
|
|
||||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
|
||||||
hiddens.append(x) # Save hidden states for skip connections
|
|
||||||
x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
|
|
||||||
masks.append(mask_down[:, :, ::2])
|
|
||||||
masks = masks[:-1]
|
|
||||||
mask_mid = masks[-1]
|
|
||||||
|
|
||||||
for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
|
|
||||||
x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
|
|
||||||
resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
|
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
||||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
|
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
|
||||||
for i, transformer_block in enumerate(transformer_blocks):
|
|
||||||
x, mid_blocks_kv_cache_new[index, i] = transformer_block(
|
|
||||||
hidden_states=x,
|
|
||||||
attention_mask=attn_mask,
|
|
||||||
timestep=t,
|
|
||||||
cache=mid_blocks_kv_cache[index, i]
|
|
||||||
)
|
|
||||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
|
||||||
|
|
||||||
for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
|
|
||||||
mask_up = masks.pop()
|
|
||||||
skip = hiddens.pop()
|
|
||||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
|
||||||
x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
|
|
||||||
resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
|
|
||||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
||||||
attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
|
|
||||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
|
||||||
for i, transformer_block in enumerate(transformer_blocks):
|
|
||||||
x, up_blocks_kv_cache_new[index, i] = transformer_block(
|
|
||||||
hidden_states=x,
|
|
||||||
attention_mask=attn_mask,
|
|
||||||
timestep=t,
|
|
||||||
cache=up_blocks_kv_cache[index, i]
|
|
||||||
)
|
|
||||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
|
||||||
x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
|
|
||||||
x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
|
|
||||||
output = self.final_proj(x * mask_up)
|
|
||||||
return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
|
|
||||||
up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
|
|
||||||
|
|||||||
@@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
prompt_feat,
|
prompt_feat,
|
||||||
prompt_feat_len,
|
prompt_feat_len,
|
||||||
embedding,
|
embedding,
|
||||||
cache,
|
|
||||||
finalize):
|
finalize):
|
||||||
assert token.shape[0] == 1
|
assert token.shape[0] == 1
|
||||||
# xvec projection
|
# xvec projection
|
||||||
@@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
|
|
||||||
# text encode
|
# text encode
|
||||||
if finalize is True:
|
if finalize is True:
|
||||||
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
|
h, h_lengths = self.encoder(token, token_len)
|
||||||
else:
|
else:
|
||||||
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
||||||
h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
|
h, h_lengths = self.encoder(token, token_len, context=context)
|
||||||
cache['encoder_cache']['offset'] = encoder_cache[0]
|
|
||||||
cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
|
|
||||||
cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
|
|
||||||
cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
|
|
||||||
cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
|
|
||||||
cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
|
|
||||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||||
h = self.encoder_proj(h)
|
h = self.encoder_proj(h)
|
||||||
|
|
||||||
@@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
conds = conds.transpose(1, 2)
|
conds = conds.transpose(1, 2)
|
||||||
|
|
||||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||||
feat, cache['decoder_cache'] = self.decoder(
|
feat, _ = self.decoder(
|
||||||
mu=h.transpose(1, 2).contiguous(),
|
mu=h.transpose(1, 2).contiguous(),
|
||||||
mask=mask.unsqueeze(1),
|
mask=mask.unsqueeze(1),
|
||||||
spks=embedding,
|
spks=embedding,
|
||||||
cond=conds,
|
cond=conds,
|
||||||
n_timesteps=10,
|
n_timesteps=10,
|
||||||
cache=cache['decoder_cache']
|
|
||||||
)
|
)
|
||||||
feat = feat[:, :, mel_len1:]
|
feat = feat[:, :, mel_len1:]
|
||||||
assert feat.shape[2] == mel_len2
|
assert feat.shape[2] == mel_len2
|
||||||
return feat.float(), cache
|
return feat.float(), None
|
||||||
|
|||||||
@@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM):
|
|||||||
if isinstance(self.estimator, torch.nn.Module):
|
if isinstance(self.estimator, torch.nn.Module):
|
||||||
return self.estimator(x, mask, mu, t, spks, cond)
|
return self.estimator(x, mask, mu, t, spks, cond)
|
||||||
else:
|
else:
|
||||||
with self.lock:
|
estimator, trt_engine = self.estimator.acquire_estimator()
|
||||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||||
self.estimator.set_input_shape('t', (2,))
|
estimator.set_input_shape('t', (2,))
|
||||||
self.estimator.set_input_shape('spks', (2, 80))
|
estimator.set_input_shape('spks', (2, 80))
|
||||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||||
# run trt engine
|
data_ptrs = [x.contiguous().data_ptr(),
|
||||||
assert self.estimator.execute_v2([x.contiguous().data_ptr(),
|
mask.contiguous().data_ptr(),
|
||||||
mask.contiguous().data_ptr(),
|
mu.contiguous().data_ptr(),
|
||||||
mu.contiguous().data_ptr(),
|
t.contiguous().data_ptr(),
|
||||||
t.contiguous().data_ptr(),
|
spks.contiguous().data_ptr(),
|
||||||
spks.contiguous().data_ptr(),
|
cond.contiguous().data_ptr(),
|
||||||
cond.contiguous().data_ptr(),
|
x.data_ptr()]
|
||||||
x.data_ptr()]) is True
|
for i, j in enumerate(data_ptrs):
|
||||||
|
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||||
|
# run trt engine
|
||||||
|
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
self.estimator.release_estimator(estimator)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||||
@@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
|||||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||||
"""Forward diffusion
|
"""Forward diffusion
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM):
|
|||||||
shape: (batch_size, n_feats, mel_timesteps)
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
offset = cache.pop('offset')
|
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||||
z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
|
|
||||||
z = z[:, :, offset:]
|
|
||||||
offset += mu.size(2)
|
|
||||||
# fix prompt and overlap part mu and z
|
# fix prompt and overlap part mu and z
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||||
if self.t_scheduler == 'cosine':
|
if self.t_scheduler == 'cosine':
|
||||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||||
mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
|
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||||
cache['offset'] = offset
|
|
||||||
return mel, cache
|
|
||||||
|
|
||||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
|
|
||||||
"""
|
|
||||||
Fixed euler solver for ODEs.
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): random noise
|
|
||||||
t_span (torch.Tensor): n_timesteps interpolated
|
|
||||||
shape: (n_timesteps + 1,)
|
|
||||||
mu (torch.Tensor): output of encoder
|
|
||||||
shape: (batch_size, n_feats, mel_timesteps)
|
|
||||||
mask (torch.Tensor): output_mask
|
|
||||||
shape: (batch_size, 1, mel_timesteps)
|
|
||||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
|
||||||
shape: (batch_size, spk_emb_dim)
|
|
||||||
cond: Not used but kept for future purposes
|
|
||||||
"""
|
|
||||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
|
||||||
t = t.unsqueeze(dim=0)
|
|
||||||
|
|
||||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
|
||||||
# Or in future might add like a return_all_steps flag
|
|
||||||
sol = []
|
|
||||||
|
|
||||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
|
||||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
|
||||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
|
||||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
|
||||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
|
||||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
|
||||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
|
||||||
flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
|
|
||||||
for step in range(1, len(t_span)):
|
|
||||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
|
||||||
x_in[:] = x
|
|
||||||
mask_in[:] = mask
|
|
||||||
mu_in[0] = mu
|
|
||||||
t_in[:] = t.unsqueeze(0)
|
|
||||||
spks_in[0] = spks
|
|
||||||
cond_in[0] = cond
|
|
||||||
cache_step = {k: v[step - 1] for k, v in cache.items()}
|
|
||||||
dphi_dt, cache_step = self.forward_estimator(
|
|
||||||
x_in, mask_in,
|
|
||||||
mu_in, t_in,
|
|
||||||
spks_in,
|
|
||||||
cond_in,
|
|
||||||
cache_step
|
|
||||||
)
|
|
||||||
# NOTE if smaller than flow_cache_size, means last chunk, no need to cache
|
|
||||||
if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
|
|
||||||
cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
|
|
||||||
cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
|
|
||||||
cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
|
|
||||||
cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
|
|
||||||
cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
|
|
||||||
cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
|
|
||||||
cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
|
|
||||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
|
||||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
|
||||||
x = x + dt * dphi_dt
|
|
||||||
t = t + dt
|
|
||||||
sol.append(x)
|
|
||||||
if step < len(t_span) - 1:
|
|
||||||
dt = t_span[step + 1] - t
|
|
||||||
return sol[-1].float(), cache
|
|
||||||
|
|
||||||
def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
|
|
||||||
if isinstance(self.estimator, torch.nn.Module):
|
|
||||||
x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
|
|
||||||
cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
|
|
||||||
else:
|
|
||||||
estimator, trt_engine = self.estimator.acquire_estimator()
|
|
||||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
|
||||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
|
||||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
|
||||||
estimator.set_input_shape('t', (2,))
|
|
||||||
estimator.set_input_shape('spks', (2, 80))
|
|
||||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
|
||||||
estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
|
|
||||||
estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
|
|
||||||
estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
|
|
||||||
estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
|
|
||||||
estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
|
|
||||||
estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
|
|
||||||
estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
|
|
||||||
down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
|
||||||
mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
|
|
||||||
up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
|
|
||||||
data_ptrs = [x.contiguous().data_ptr(),
|
|
||||||
mask.contiguous().data_ptr(),
|
|
||||||
mu.contiguous().data_ptr(),
|
|
||||||
t.contiguous().data_ptr(),
|
|
||||||
spks.contiguous().data_ptr(),
|
|
||||||
cond.contiguous().data_ptr(),
|
|
||||||
cache['down_blocks_conv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['down_blocks_kv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['up_blocks_conv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['up_blocks_kv_cache'].contiguous().data_ptr(),
|
|
||||||
cache['final_blocks_conv_cache'].contiguous().data_ptr(),
|
|
||||||
x.data_ptr(),
|
|
||||||
cache['down_blocks_conv_cache'].data_ptr(),
|
|
||||||
down_blocks_kv_cache_out.data_ptr(),
|
|
||||||
cache['mid_blocks_conv_cache'].data_ptr(),
|
|
||||||
mid_blocks_kv_cache_out.data_ptr(),
|
|
||||||
cache['up_blocks_conv_cache'].data_ptr(),
|
|
||||||
up_blocks_kv_cache_out.data_ptr(),
|
|
||||||
cache['final_blocks_conv_cache'].data_ptr()]
|
|
||||||
for i, j in enumerate(data_ptrs):
|
|
||||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
|
||||||
# run trt engine
|
|
||||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
|
||||||
torch.cuda.current_stream().synchronize()
|
|
||||||
self.estimator.release_estimator(estimator)
|
|
||||||
cache = (cache['down_blocks_conv_cache'],
|
|
||||||
down_blocks_kv_cache_out,
|
|
||||||
cache['mid_blocks_conv_cache'],
|
|
||||||
mid_blocks_kv_cache_out,
|
|
||||||
cache['up_blocks_conv_cache'],
|
|
||||||
up_blocks_kv_cache_out,
|
|
||||||
cache['final_blocks_conv_cache'])
|
|
||||||
return x, cache
|
|
||||||
|
|||||||
@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
|
|||||||
return sine_merge, noise, uv
|
return sine_merge, noise, uv
|
||||||
|
|
||||||
|
|
||||||
|
class SineGen2(torch.nn.Module):
|
||||||
|
""" Definition of sine generator
|
||||||
|
SineGen(samp_rate, harmonic_num = 0,
|
||||||
|
sine_amp = 0.1, noise_std = 0.003,
|
||||||
|
voiced_threshold = 0,
|
||||||
|
flag_for_pulse=False)
|
||||||
|
samp_rate: sampling rate in Hz
|
||||||
|
harmonic_num: number of harmonic overtones (default 0)
|
||||||
|
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||||
|
noise_std: std of Gaussian noise (default 0.003)
|
||||||
|
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||||
|
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||||
|
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||||
|
segment is always sin(np.pi) or cos(0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||||
|
sine_amp=0.1, noise_std=0.003,
|
||||||
|
voiced_threshold=0,
|
||||||
|
flag_for_pulse=False):
|
||||||
|
super(SineGen2, self).__init__()
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = noise_std
|
||||||
|
self.harmonic_num = harmonic_num
|
||||||
|
self.dim = self.harmonic_num + 1
|
||||||
|
self.sampling_rate = samp_rate
|
||||||
|
self.voiced_threshold = voiced_threshold
|
||||||
|
self.flag_for_pulse = flag_for_pulse
|
||||||
|
self.upsample_scale = upsample_scale
|
||||||
|
|
||||||
|
def _f02uv(self, f0):
|
||||||
|
# generate uv signal
|
||||||
|
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||||
|
return uv
|
||||||
|
|
||||||
|
def _f02sine(self, f0_values):
|
||||||
|
""" f0_values: (batchsize, length, dim)
|
||||||
|
where dim indicates fundamental tone and overtones
|
||||||
|
"""
|
||||||
|
# convert to F0 in rad. The interger part n can be ignored
|
||||||
|
# because 2 * np.pi * n doesn't affect phase
|
||||||
|
rad_values = (f0_values / self.sampling_rate) % 1
|
||||||
|
|
||||||
|
# initial phase noise (no noise for fundamental component)
|
||||||
|
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||||
|
rand_ini[:, 0] = 0
|
||||||
|
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||||
|
|
||||||
|
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||||
|
if not self.flag_for_pulse:
|
||||||
|
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||||
|
scale_factor=1 / self.upsample_scale,
|
||||||
|
mode="linear").transpose(1, 2)
|
||||||
|
|
||||||
|
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||||
|
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||||
|
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
||||||
|
sines = torch.sin(phase)
|
||||||
|
else:
|
||||||
|
# If necessary, make sure that the first time step of every
|
||||||
|
# voiced segments is sin(pi) or cos(0)
|
||||||
|
# This is used for pulse-train generation
|
||||||
|
|
||||||
|
# identify the last time step in unvoiced segments
|
||||||
|
uv = self._f02uv(f0_values)
|
||||||
|
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||||
|
uv_1[:, -1, :] = 1
|
||||||
|
u_loc = (uv < 1) * (uv_1 > 0)
|
||||||
|
|
||||||
|
# get the instantanouse phase
|
||||||
|
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||||
|
# different batch needs to be processed differently
|
||||||
|
for idx in range(f0_values.shape[0]):
|
||||||
|
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||||
|
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||||
|
# stores the accumulation of i.phase within
|
||||||
|
# each voiced segments
|
||||||
|
tmp_cumsum[idx, :, :] = 0
|
||||||
|
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||||
|
|
||||||
|
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||||
|
# within the previous voiced segment.
|
||||||
|
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||||
|
|
||||||
|
# get the sines
|
||||||
|
sines = torch.cos(i_phase * 2 * np.pi)
|
||||||
|
return sines
|
||||||
|
|
||||||
|
def forward(self, f0):
|
||||||
|
""" sine_tensor, uv = forward(f0)
|
||||||
|
input F0: tensor(batchsize=1, length, dim=1)
|
||||||
|
f0 for unvoiced steps should be 0
|
||||||
|
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||||
|
output uv: tensor(batchsize=1, length, 1)
|
||||||
|
"""
|
||||||
|
# fundamental component
|
||||||
|
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||||
|
|
||||||
|
# generate sine waveforms
|
||||||
|
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||||
|
|
||||||
|
# generate uv signal
|
||||||
|
uv = self._f02uv(f0)
|
||||||
|
|
||||||
|
# noise: for unvoiced should be similar to sine_amp
|
||||||
|
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||||
|
# . for voiced regions is self.noise_std
|
||||||
|
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||||
|
noise = noise_amp * torch.randn_like(sine_waves)
|
||||||
|
|
||||||
|
# first: set the unvoiced part to 0 by uv
|
||||||
|
# then: additive noise
|
||||||
|
sine_waves = sine_waves * uv + noise
|
||||||
|
return sine_waves, uv, noise
|
||||||
|
|
||||||
|
|
||||||
|
class SourceModuleHnNSF2(torch.nn.Module):
|
||||||
|
""" SourceModule for hn-nsf
|
||||||
|
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0)
|
||||||
|
sampling_rate: sampling_rate in Hz
|
||||||
|
harmonic_num: number of harmonic above F0 (default: 0)
|
||||||
|
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||||
|
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||||
|
note that amplitude of noise in unvoiced is decided
|
||||||
|
by sine_amp
|
||||||
|
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||||
|
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||||
|
F0_sampled (batchsize, length, 1)
|
||||||
|
Sine_source (batchsize, length, 1)
|
||||||
|
noise_source (batchsize, length 1)
|
||||||
|
uv (batchsize, length, 1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||||
|
add_noise_std=0.003, voiced_threshod=0):
|
||||||
|
super(SourceModuleHnNSF2, self).__init__()
|
||||||
|
|
||||||
|
self.sine_amp = sine_amp
|
||||||
|
self.noise_std = add_noise_std
|
||||||
|
|
||||||
|
# to produce sine waveforms
|
||||||
|
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
|
||||||
|
sine_amp, add_noise_std, voiced_threshod)
|
||||||
|
|
||||||
|
# to merge source harmonics into a single excitation
|
||||||
|
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||||
|
self.l_tanh = torch.nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||||
|
F0_sampled (batchsize, length, 1)
|
||||||
|
Sine_source (batchsize, length, 1)
|
||||||
|
noise_source (batchsize, length 1)
|
||||||
|
"""
|
||||||
|
# source for harmonic branch
|
||||||
|
with torch.no_grad():
|
||||||
|
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||||
|
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||||
|
|
||||||
|
# source for noise branch, in the same shape as uv
|
||||||
|
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||||
|
return sine_merge, noise, uv
|
||||||
|
|
||||||
|
|
||||||
class HiFTGenerator(nn.Module):
|
class HiFTGenerator(nn.Module):
|
||||||
"""
|
"""
|
||||||
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||||
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):
|
|||||||
|
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
self.m_source = SourceModuleHnNSF(
|
# NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
|
||||||
|
this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
|
||||||
|
self.m_source = this_SourceModuleHnNSF(
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||||
harmonic_num=nb_harmonics,
|
harmonic_num=nb_harmonics,
|
||||||
|
|||||||
@@ -56,16 +56,11 @@ class Upsample1D(nn.Module):
|
|||||||
# In this mode, first repeat interpolate, than conv with stride=1
|
# In this mode, first repeat interpolate, than conv with stride=1
|
||||||
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
||||||
if conv_cache.size(2) == 0:
|
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
||||||
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
|
||||||
else:
|
|
||||||
assert conv_cache.size(2) == self.stride * 2
|
|
||||||
outputs = torch.concat([conv_cache, outputs], dim=2)
|
|
||||||
conv_cache_new = outputs[:, :, -self.stride * 2:]
|
|
||||||
outputs = self.conv(outputs)
|
outputs = self.conv(outputs)
|
||||||
return outputs, input_lengths * self.stride, conv_cache_new
|
return outputs, input_lengths * self.stride
|
||||||
|
|
||||||
|
|
||||||
class PreLookaheadLayer(nn.Module):
|
class PreLookaheadLayer(nn.Module):
|
||||||
@@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module):
|
|||||||
kernel_size=3, stride=1, padding=0,
|
kernel_size=3, stride=1, padding=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
inputs: (batch_size, seq_len, channels)
|
inputs: (batch_size, seq_len, channels)
|
||||||
"""
|
"""
|
||||||
@@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module):
|
|||||||
if context.size(2) == 0:
|
if context.size(2) == 0:
|
||||||
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
||||||
else:
|
else:
|
||||||
|
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||||
assert context.size(2) == self.pre_lookahead_len
|
assert context.size(2) == self.pre_lookahead_len
|
||||||
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
||||||
outputs = F.leaky_relu(self.conv1(outputs))
|
outputs = F.leaky_relu(self.conv1(outputs))
|
||||||
# outputs
|
# outputs
|
||||||
if conv2_cache.size(2) == 0:
|
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
||||||
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
|
||||||
else:
|
|
||||||
assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
|
|
||||||
outputs = torch.concat([conv2_cache, outputs], dim=2)
|
|
||||||
conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
|
|
||||||
outputs = self.conv2(outputs)
|
outputs = self.conv2(outputs)
|
||||||
outputs = outputs.transpose(1, 2).contiguous()
|
outputs = outputs.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
# residual connection
|
# residual connection
|
||||||
outputs = outputs + inputs
|
outputs = outputs + inputs
|
||||||
return outputs, conv2_cache_new
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class UpsampleConformerEncoder(torch.nn.Module):
|
class UpsampleConformerEncoder(torch.nn.Module):
|
||||||
@@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
xs: torch.Tensor,
|
xs: torch.Tensor,
|
||||||
xs_lens: torch.Tensor,
|
xs_lens: torch.Tensor,
|
||||||
|
context: torch.Tensor = torch.zeros(0, 0, 0),
|
||||||
decoding_chunk_size: int = 0,
|
decoding_chunk_size: int = 0,
|
||||||
num_decoding_left_chunks: int = -1,
|
num_decoding_left_chunks: int = -1,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
@@ -280,20 +272,27 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|||||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||||
"""
|
"""
|
||||||
|
if hasattr(self, 'streaming'):
|
||||||
|
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
|
||||||
|
streaming = self.streaming
|
||||||
T = xs.size(1)
|
T = xs.size(1)
|
||||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
xs = self.global_cmvn(xs)
|
xs = self.global_cmvn(xs)
|
||||||
xs, pos_emb, masks = self.embed(xs, masks)
|
xs, pos_emb, masks = self.embed(xs, masks)
|
||||||
|
if context.size(1) != 0:
|
||||||
|
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||||
|
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
||||||
|
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
||||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||||
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
||||||
# lookahead + conformer encoder
|
# lookahead + conformer encoder
|
||||||
xs, _ = self.pre_lookahead_layer(xs)
|
xs = self.pre_lookahead_layer(xs, context=context)
|
||||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
|
|
||||||
# upsample + conformer encoder
|
# upsample + conformer encoder
|
||||||
xs = xs.transpose(1, 2).contiguous()
|
xs = xs.transpose(1, 2).contiguous()
|
||||||
xs, xs_lens, _ = self.up_layer(xs, xs_lens)
|
xs, xs_lens = self.up_layer(xs, xs_lens)
|
||||||
xs = xs.transpose(1, 2).contiguous()
|
xs = xs.transpose(1, 2).contiguous()
|
||||||
T = xs.size(1)
|
T = xs.size(1)
|
||||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||||
@@ -322,100 +321,3 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|||||||
for layer in self.up_encoders:
|
for layer in self.up_encoders:
|
||||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||||
return xs
|
return xs
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def forward_chunk(
|
|
||||||
self,
|
|
||||||
xs: torch.Tensor,
|
|
||||||
xs_lens: torch.Tensor,
|
|
||||||
offset: int = 0,
|
|
||||||
context: torch.Tensor = torch.zeros(0, 0, 0),
|
|
||||||
pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
|
|
||||||
encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
|
|
||||||
upsample_offset: int = 0,
|
|
||||||
upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
|
|
||||||
upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
|
|
||||||
"""Embed positions in tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
xs: padded input tensor (B, T, D)
|
|
||||||
xs_lens: input length (B)
|
|
||||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
|
||||||
0: default for training, use random dynamic chunk.
|
|
||||||
<0: for decoding, use full chunk.
|
|
||||||
>0: for decoding, use fixed chunk size as set.
|
|
||||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
|
||||||
the chunk size is decoding_chunk_size.
|
|
||||||
>=0: use num_decoding_left_chunks
|
|
||||||
<0: use all left chunks
|
|
||||||
Returns:
|
|
||||||
encoder output tensor xs, and subsampled masks
|
|
||||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
|
||||||
masks: torch.Tensor batch padding mask after subsample
|
|
||||||
(B, 1, T' ~= T/subsample_rate)
|
|
||||||
NOTE(xcsong):
|
|
||||||
We pass the `__call__` method of the modules instead of `forward` to the
|
|
||||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
|
||||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
|
||||||
"""
|
|
||||||
assert xs.size(0) == 1
|
|
||||||
# tmp_masks is just for interface compatibility
|
|
||||||
tmp_masks = torch.ones(1,
|
|
||||||
xs.size(1),
|
|
||||||
device=xs.device,
|
|
||||||
dtype=torch.bool)
|
|
||||||
tmp_masks = tmp_masks.unsqueeze(1)
|
|
||||||
if self.global_cmvn is not None:
|
|
||||||
xs = self.global_cmvn(xs)
|
|
||||||
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
|
||||||
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
|
||||||
offset += xs.size(1)
|
|
||||||
tmp_masks = torch.ones(1,
|
|
||||||
context.size(1),
|
|
||||||
device=context.device,
|
|
||||||
dtype=torch.bool)
|
|
||||||
tmp_masks = tmp_masks.unsqueeze(1)
|
|
||||||
if context.size(1) != 0:
|
|
||||||
context, _, _ = self.embed(context, tmp_masks, offset)
|
|
||||||
|
|
||||||
# lookahead + conformer encoder
|
|
||||||
xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
|
|
||||||
# NOTE in cache mode we do not need to call add_optional_chunk_mask
|
|
||||||
chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
|
|
||||||
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
|
||||||
encoders_kv_cache_list = []
|
|
||||||
for index, layer in enumerate(self.encoders):
|
|
||||||
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
|
|
||||||
encoders_kv_cache_list.append(encoders_kv_cache_new)
|
|
||||||
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
|
|
||||||
|
|
||||||
# upsample
|
|
||||||
xs = xs.transpose(1, 2).contiguous()
|
|
||||||
xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
|
|
||||||
xs = xs.transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# tmp_masks is just for interface compatibility
|
|
||||||
tmp_masks = torch.ones(1,
|
|
||||||
xs.size(1),
|
|
||||||
device=xs.device,
|
|
||||||
dtype=torch.bool)
|
|
||||||
tmp_masks = tmp_masks.unsqueeze(1)
|
|
||||||
xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
|
|
||||||
upsample_offset += xs.size(1)
|
|
||||||
|
|
||||||
# conformer encoder
|
|
||||||
chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
|
|
||||||
mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
|
|
||||||
upsample_kv_cache_list = []
|
|
||||||
for index, layer in enumerate(self.up_encoders):
|
|
||||||
xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
|
|
||||||
upsample_kv_cache_list.append(upsample_kv_cache_new)
|
|
||||||
upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
|
|
||||||
|
|
||||||
if self.normalize_before:
|
|
||||||
xs = self.after_norm(xs)
|
|
||||||
# Here we assume the mask is not changed in encoder layers, so just
|
|
||||||
# return the masks before encoder layers, and the masks will be used
|
|
||||||
# for cross attention with decoder later
|
|
||||||
return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
|||||||
network = builder.create_network(network_flags)
|
network = builder.create_network(network_flags)
|
||||||
parser = trt.OnnxParser(network, logger)
|
parser = trt.OnnxParser(network, logger)
|
||||||
config = builder.create_builder_config()
|
config = builder.create_builder_config()
|
||||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
|
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31) # 1GB
|
||||||
if fp16:
|
if fp16:
|
||||||
config.set_flag(trt.BuilderFlag.FP16)
|
config.set_flag(trt.BuilderFlag.FP16)
|
||||||
profile = builder.create_optimization_profile()
|
profile = builder.create_optimization_profile()
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ def subsequent_mask(
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def subsequent_chunk_mask(
|
def subsequent_chunk_mask_deprecated(
|
||||||
size: int,
|
size: int,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
num_left_chunks: int = -1,
|
num_left_chunks: int = -1,
|
||||||
@@ -124,6 +124,40 @@ def subsequent_chunk_mask(
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def subsequent_chunk_mask(
|
||||||
|
size: int,
|
||||||
|
chunk_size: int,
|
||||||
|
num_left_chunks: int = -1,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||||
|
this is for streaming encoder
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (int): size of mask
|
||||||
|
chunk_size (int): size of chunk
|
||||||
|
num_left_chunks (int): number of left chunks
|
||||||
|
<0: use full chunk
|
||||||
|
>=0: use num_left_chunks
|
||||||
|
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: mask
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> subsequent_chunk_mask(4, 2)
|
||||||
|
[[1, 1, 0, 0],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1]]
|
||||||
|
"""
|
||||||
|
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
||||||
|
pos_idx = torch.arange(size, device=device)
|
||||||
|
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
||||||
|
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def add_optional_chunk_mask(xs: torch.Tensor,
|
def add_optional_chunk_mask(xs: torch.Tensor,
|
||||||
masks: torch.Tensor,
|
masks: torch.Tensor,
|
||||||
use_dynamic_chunk: bool,
|
use_dynamic_chunk: bool,
|
||||||
@@ -196,9 +230,6 @@ def add_optional_chunk_mask(xs: torch.Tensor,
|
|||||||
else:
|
else:
|
||||||
chunk_masks = masks
|
chunk_masks = masks
|
||||||
assert chunk_masks.dtype == torch.bool
|
assert chunk_masks.dtype == torch.bool
|
||||||
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
|
||||||
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
|
||||||
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
|
||||||
return chunk_masks
|
return chunk_masks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
37
test1.py
37
test1.py
@@ -1,37 +0,0 @@
|
|||||||
import sys
|
|
||||||
sys.path.append('third_party/Matcha-TTS')
|
|
||||||
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
|
||||||
from cosyvoice.utils.file_utils import load_wav
|
|
||||||
import torchaudio # type: ignore
|
|
||||||
|
|
||||||
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
|
|
||||||
|
|
||||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
|
||||||
# zero_shot usage
|
|
||||||
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
|
|
||||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
|
||||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
|
||||||
|
|
||||||
# save zero_shot spk for future usage
|
|
||||||
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
|
|
||||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
|
|
||||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
|
||||||
cosyvoice.save_spkinfo()
|
|
||||||
|
|
||||||
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
|
||||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
|
|
||||||
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
|
||||||
|
|
||||||
# instruct usage
|
|
||||||
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
|
|
||||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
|
||||||
|
|
||||||
# bistream usage, you can use generator as input, this is useful when using text llm model as input
|
|
||||||
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
|
|
||||||
def text_generator():
|
|
||||||
yield '收到好友从远方寄来的生日礼物,'
|
|
||||||
yield '那份意外的惊喜与深深的祝福'
|
|
||||||
yield '让我心中充满了甜蜜的快乐,'
|
|
||||||
yield '笑容如花儿般绽放。'
|
|
||||||
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
|
||||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
|
||||||
Reference in New Issue
Block a user