Mirror of https://github.com/FunAudioLLM/CosyVoice.git, synced 2026-02-04 17:39:25 +08:00
Commit: update
@@ -132,7 +132,7 @@ import torchaudio
 
 **CosyVoice2 Usage**
 ```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
 
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
 # zero_shot usage
@@ -151,7 +151,7 @@ for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来
 
 **CosyVoice Usage**
 ```python
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
 # sft usage
 print(cosyvoice.list_available_spks())
 # change stream=True for chunk stream inference
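For reference, a hedged sketch (not part of this commit) of how the updated CosyVoice2 constructor is typically paired with zero-shot inference; the prompt clip, its transcript, and the target text below are placeholders.

```python
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# placeholder prompt audio and texts, assuming a 16 kHz reference clip
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('Text to synthesize.', 'Transcript of the prompt clip.', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```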
@@ -23,7 +23,7 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 
 
 def get_args():
@@ -37,6 +37,15 @@ def get_args():
     return args
 
 
+def get_optimized_script(model, preserved_attrs=[]):
+    script = torch.jit.script(model)
+    if preserved_attrs != []:
+        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+    else:
+        script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    return script
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
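A hedged, self-contained sketch of the freeze-and-optimize pattern the new get_optimized_script helper wraps; the Toy module, its shapes, and the forward_chunk method are illustrative assumptions, not repository code.

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

    @torch.jit.export
    def forward_chunk(self, x: torch.Tensor) -> torch.Tensor:
        # extra entry point that torch.jit.freeze would drop unless preserved
        return self.proj(x)


model = Toy().eval()  # freezing requires eval mode
script = torch.jit.script(model)
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
script = torch.jit.optimize_for_inference(script)
print(script(torch.randn(2, 8)).shape)   # optimized forward still runs
print(hasattr(script, 'forward_chunk'))  # preserved method survives freezing
```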
@@ -46,28 +55,35 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
-
-    # 1. export llm text_encoder
-    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
-    script = torch.jit.script(llm_text_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
-
-    # 2. export llm llm
-    llm_llm = cosyvoice.model.llm.llm.half()
-    script = torch.jit.script(llm_llm)
-    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
-
-    # 3. export flow encoder
-    flow_encoder = cosyvoice.model.flow.encoder
-    script = torch.jit.script(flow_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+    try:
+        model = CosyVoice(args.model_dir)
+    except:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except:
+            raise TypeError('no valid model_type!')
+
+    if not isinstance(model, CosyVoice2):
+        # 1. export llm text_encoder
+        llm_text_encoder = model.model.llm.text_encoder
+        script = get_optimized_script(llm_text_encoder)
+        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_text_encoder.half())
+        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+        # 2. export llm llm
+        llm_llm = model.model.llm.llm
+        script = get_optimized_script(llm_llm, ['forward_chunk'])
+        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+    # 3. export flow encoder
+    flow_encoder = model.model.flow.encoder
+    script = get_optimized_script(flow_encoder)
+    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+    script = get_optimized_script(flow_encoder.half())
+    script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
 
 
 if __name__ == '__main__':
@@ -27,7 +27,7 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 
 
 def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -56,14 +56,20 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+    try:
+        model = CosyVoice(args.model_dir)
+    except:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except:
+            raise TypeError('no valid model_type!')
 
     # 1. export flow decoder estimator
-    estimator = cosyvoice.model.flow.decoder.estimator
+    estimator = model.model.flow.decoder.estimator
 
-    device = cosyvoice.model.device
-    batch_size, seq_len = 1, 256
-    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+    device = model.model.device
+    batch_size, seq_len = 2, 256
+    out_channels = model.model.flow.decoder.estimator.out_channels
     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
     torch.onnx.export(
         estimator,
@@ -75,13 +81,11 @@ def main():
         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
         output_names=['estimator_out'],
         dynamic_axes={
-            'x': {0: 'batch_size', 2: 'seq_len'},
-            'mask': {0: 'batch_size', 2: 'seq_len'},
-            'mu': {0: 'batch_size', 2: 'seq_len'},
-            'cond': {0: 'batch_size', 2: 'seq_len'},
-            't': {0: 'batch_size'},
-            'spks': {0: 'batch_size'},
-            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
         }
     )
 
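A hedged illustration of why the batch axes were dropped from dynamic_axes and batch_size is now 2: with classifier-free guidance the estimator always receives the conditional and unconditional branches stacked on the batch dimension, so only the time axis has to stay dynamic. The sizes below are arbitrary.

```python
import torch

out_channels, seq_len = 80, 193          # arbitrary example sizes
x_cond = torch.randn(1, out_channels, seq_len)

# conditional + unconditional branch share one fixed-size batch of 2
x_in = torch.zeros(2, out_channels, seq_len)
x_in[:] = x_cond
print(x_in.shape)  # torch.Size([2, 80, 193]); only the last axis varies at run time
```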
@@ -94,7 +98,7 @@ def main():
                                                   sess_options=option, providers=providers)
 
     for _ in tqdm(range(10)):
-        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
         output_pytorch = estimator(x, mask, mu, t, spks, cond)
         ort_inputs = {
             'x': x.cpu().numpy(),
@@ -6,4 +6,5 @@ TRT_DIR=<YOUR_TRT_DIR>
 MODEL_DIR=<COSYVOICE2_MODEL_DIR>
 
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
+$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
 $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
@@ -25,14 +25,15 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
-        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
+        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
@@ -40,20 +41,19 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
-            load_jit = False
-            fp16 = False
-            logging.warning('cpu do not support fp16 and jit, force set to False')
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir),
-                                '{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         del configs
 
     def list_available_spks(self):
@@ -123,9 +123,10 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
@@ -138,22 +139,17 @@ class CosyVoice2(CosyVoice):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and load_jit is True:
-            load_jit = False
-            logging.warning('cpu do not support jit, force set to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_trt is True and load_onnx is True:
-            load_onnx = False
-            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         del configs
 
     def inference_instruct(self, *args, **kwargs):
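A hedged usage sketch tying the unified flags together: on a CUDA machine with a pre-built flow.decoder.estimator.fp16.v100.plan already placed in the model directory, the TensorRT path might be enabled as follows; the model path is a placeholder.

```python
from cosyvoice.cli.cosyvoice import CosyVoice2

# assumes a CUDA device and an fp16 TensorRT plan exported beforehand
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                       load_jit=True, load_trt=True, fp16=True)
```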
@@ -33,6 +33,8 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -61,17 +63,17 @@ class CosyVoiceModel:
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
@@ -79,18 +81,16 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    def load_trt(self, flow_decoder_estimator_model):
         del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        if self.fp16 is True:
-            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
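A hedged standalone sketch of the TensorRT loading sequence the new load_trt() relies on, restricted to the calls visible in this diff; the plan path is a placeholder and a CUDA-capable environment is assumed.

```python
import tensorrt as trt

plan_path = 'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.v100.plan'  # placeholder
with open(plan_path, 'rb') as f:
    engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
if engine is None:
    raise ValueError('failed to load trt {}'.format(plan_path))
context = engine.create_execution_context()
# the execution context then stands in for the PyTorch estimator module
```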
@@ -259,16 +259,20 @@ class CosyVoiceModel:
             self.hift_cache_dict.pop(this_uuid)
 
 
-class CosyVoice2Model:
+class CosyVoice2Model(CosyVoiceModel):
 
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.fp16 = fp16
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -287,52 +291,10 @@ class CosyVoice2Model:
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
 
-    def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
-        self.llm.to(self.device).eval()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
-        self.flow.to(self.device).eval()
-        self.flow.decoder.fp16 = False
-        # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=True)
-        self.hift.to(self.device).eval()
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
-
-    def load_trt(self, flow_decoder_estimator_model):
-        del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
-        self.flow.decoder.fp16 = True
-
-    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
-            self.llm_end_dict[uuid] = True
-
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -111,6 +111,10 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat_len,
                   embedding,
                   flow_cache):
+        if self.fp16 is True:
+            prompt_feat = prompt_feat.half()
+            embedding = embedding.half()
+
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -129,7 +133,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
@@ -145,7 +149,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat, flow_cache
+        return feat.float(), flow_cache
 
 
 class CausalMaskedDiffWithXvec(torch.nn.Module):
@@ -196,6 +200,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat_len,
                   embedding,
                   finalize):
+        if self.fp16 is True:
+            prompt_feat = prompt_feat.half()
+            embedding = embedding.half()
+
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -214,7 +222,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
@@ -228,4 +236,4 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat, None
+        return feat.float(), None
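A hedged toy illustration (not repository code) of the cast-in/cast-out convention adopted here: inputs and freshly created buffers follow the module's working dtype, while callers always get float32 back.

```python
import torch

feat = torch.randn(1, 80, 50)                        # pretend fp32 features from the caller
half_feat = feat.half()                              # cast down at the module boundary (fp16 path)
conds = torch.zeros(1, 80, 50).to(half_feat.dtype)   # buffers follow the working dtype
out = (half_feat + conds).float()                    # result handed back as fp32
print(half_feat.dtype, conds.dtype, out.dtype)       # torch.float16 torch.float16 torch.float32
```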
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -52,7 +51,7 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        z = torch.randn_like(mu) * temperature
+        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
         cache_size = flow_cache.shape[2]
         # fix prompt and overlap part mu and z
         if cache_size != 0:
@@ -89,36 +88,29 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
-        if self.inference_cfg_rate > 0:
-            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        else:
-            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
-            if self.inference_cfg_rate > 0:
-                x_in[:] = x
-                mask_in[:] = mask
-                mu_in[0] = mu
-                t_in[:] = t.unsqueeze(0)
-                spks_in[0] = spks
-                cond_in[0] = cond
-            else:
-                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            t_in[:] = t.unsqueeze(0)
+            spks_in[0] = spks
+            cond_in[0] = cond
             dphi_dt = self.forward_estimator(
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
                 cond_in
             )
-            if self.inference_cfg_rate > 0:
-                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
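A hedged miniature (not repository code) of the buffer-reuse pattern above: allocate the fixed CFG batch of 2 once, refresh it in place every ODE step, and avoid torch.concat so the memory layout the TensorRT engine was built against never changes. Shapes are illustrative.

```python
import torch

n_feats, seq_len, n_steps = 80, 120, 10
x = torch.randn(1, n_feats, seq_len)
mu = torch.randn(1, n_feats, seq_len)

# allocated once, outside the ODE loop
x_in = torch.zeros(2, n_feats, seq_len)
mu_in = torch.zeros(2, n_feats, seq_len)

for _ in range(n_steps):
    x_in[:] = x    # both CFG rows start from the current state
    mu_in[0] = mu  # row 0 conditional, row 1 stays zero (unconditional)
    # ... run the estimator on the fixed-shape batch of 2, then update x ...
```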
@@ -130,17 +122,6 @@ class ConditionalCFM(BASECFM):
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        elif isinstance(self.estimator, onnxruntime.InferenceSession):
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
-            }
-            output = self.estimator.run(None, ort_inputs)[0]
-            return torch.tensor(output, dtype=x.dtype, device=x.device)
         else:
             self.estimator.set_input_shape('x', (2, 80, x.size(2)))
             self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
@@ -225,9 +206,7 @@ class CausalConditionalCFM(ConditionalCFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
-        if self.fp16 is True:
-            z = z.half()
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
@@ -164,6 +164,9 @@ class TransformerLM(torch.nn.Module):
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
     ) -> Generator[torch.Tensor, None, None]:
+        if self.fp16 is True:
+            embedding = embedding.half()
+
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -178,7 +181,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -310,7 +313,7 @@ class Qwen2LM(torch.nn.Module):
         text = self.llm.model.model.embed_tokens(text)
 
         # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -75,6 +75,7 @@ COSYVOICE_ATTENTION_CLASSES = {
 
 
 def get_model_type(configs):
+    # NOTE CosyVoice2Model inherits CosyVoiceModel
     if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
@@ -86,7 +86,7 @@ def subsequent_mask(
     return mask
 
 
-def subsequent_chunk_mask(
+def subsequent_chunk_mask_deprecated(
         size: int,
         chunk_size: int,
         num_left_chunks: int = -1,
@@ -124,6 +124,41 @@ def subsequent_chunk_mask(
     return ret
 
 
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+    # actually this is not needed after we have inference cache implemented, will remove it later
+    pos_idx = torch.arange(size, device=device)
+    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
 def add_optional_chunk_mask(xs: torch.Tensor,
                             masks: torch.Tensor,
                             use_dynamic_chunk: bool,
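A hedged check, not part of the commit, that the arange-based formula reproduces the docstring example for size=4 and chunk_size=2.

```python
import torch

size, chunk_size = 4, 2
pos_idx = torch.arange(size)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
print(ret.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]], dtype=torch.int32)
```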