update stream code

lyuxiang.lx
2024-07-30 16:11:28 +08:00
parent 02f941d348
commit f4e70e222c
15 changed files with 182 additions and 109 deletions

.gitignore
View File

@@ -43,6 +43,8 @@ compile_commands.json
 # train/inference files
 *.wav
+*.m4a
+*.aac
 *.pt
 pretrained_models/*
 *_pb2_grpc.py

View File

@@ -86,23 +86,24 @@ import torchaudio
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
 # sft usage
 print(cosyvoice.list_avaliable_spks())
-output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
-torchaudio.save('sft.wav', output['tts_speech'], 22050)
+# change stream=True for chunk stream inference
+for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
+    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)

 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
-output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
-torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
 # cross_lingual usage
 prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
-output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
-torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
+    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)

 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
 # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
-output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
-torchaudio.save('instruct.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
 ```

 **Start web demo**
@@ -133,10 +134,10 @@ docker build -t cosyvoice:v1.0 .
 # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
 # for grpc usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
-python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 # for fastapi usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
-python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```

 ## Discussion & Communication

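The README examples above now treat every `inference_*` call as a generator. A minimal sketch of collecting the streamed chunks back into one waveform, assuming the same `CosyVoice-300M-SFT` setup as the snippet above (the output filename and variable names here are illustrative, not part of the commit):

```python
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
# with stream=True each yielded dict carries one audio chunk; concatenate along time (dim=1)
chunks = [out['tts_speech'] for out in
          cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=True)]
torchaudio.save('sft_full.wav', torch.concat(chunks, dim=1), 22050)
```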
View File

@@ -100,10 +100,13 @@ def main():
                            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                            'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-            model_output = model.inference(**model_input)
+            tts_speeches = []
+            for model_output in model.inference(**model_input):
+                tts_speeches.append(model_output['tts_speech'])
+            tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

View File

@@ -49,6 +49,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -60,6 +61,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -72,6 +74,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -85,6 +88,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))

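The `rtf` logged above is the real-time factor: wall-clock synthesis time divided by the duration of the audio produced. A quick numeric sketch (values hypothetical):

```python
speech_len = 32000 / 22050   # ~1.45 s of audio in one yielded chunk at 22050 Hz
elapsed = 0.5                # hypothetical wall-clock seconds spent producing it
rtf = elapsed / speech_len   # ~0.34; values below 1.0 mean generation is faster than real time
```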
View File

@@ -16,6 +16,8 @@ import numpy as np
 import threading
 import time
 from contextlib import nullcontext
+import uuid
+from cosyvoice.utils.common import fade_in_out


 class CosyVoiceModel:
@@ -28,13 +30,19 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.stream_win_len = 60 * 4
-        self.stream_hop_len = 50 * 4
-        self.overlap = 4395 * 4  # 10 token equals 4395 sample point
-        self.window = np.hamming(2 * self.overlap)
+        self.token_min_hop_len = 100
+        self.token_max_hop_len = 400
+        self.token_overlap_len = 20
+        self.speech_overlap_len = 34 * 256
+        self.window = np.hamming(2 * self.speech_overlap_len)
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token = {}
+        self.llm_end = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -44,7 +52,7 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()

-    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding):
+    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                         text_len=text_len.to(self.device),
@@ -53,13 +61,11 @@ class CosyVoiceModel:
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                         embedding=llm_embedding.to(self.device),
-                                        beam_size=1,
                                         sampling=25,
                                         max_token_text_ratio=30,
-                                        min_token_text_ratio=3,
-                                        stream=True):
-                self.tts_speech_token.append(i)
-            self.llm_end = True
+                                        min_token_text_ratio=3):
+                self.tts_speech_token[this_uuid].append(i)
+            self.llm_end[this_uuid] = True

     def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
         with self.flow_hift_context:
@@ -78,15 +84,19 @@ class CosyVoiceModel:
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
+        p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
+                                                        llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid))
+        p.start()
         if stream is True:
-            self.tts_speech_token, self.llm_end, cache_speech = [], False, None
-            p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
-                                                            llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device)))
-            p.start()
+            cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token) >= self.stream_win_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1)
+                if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                          prompt_token=flow_prompt_speech_token.to(self.device),
@@ -96,57 +106,48 @@ class CosyVoiceModel:
                                                          embedding=flow_embedding.to(self.device))
                     # fade in/out if necessary
                     if cache_speech is not None:
-                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
-                    yield {'tts_speech': this_tts_speech[:, :-self.overlap]}
-                    cache_speech = this_tts_speech[:, -self.overlap:]
+                        this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
+                    yield {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
+                    cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
+                    cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
                     with self.lock:
-                        self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:]
-                if self.llm_end is True:
+                        self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # deal with remain tokens
-            if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len:
-                this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1)
-                with self.flow_hift_context:
-                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
-                                                       token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                                       prompt_token=flow_prompt_speech_token.to(self.device),
-                                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                       prompt_feat=prompt_speech_feat.to(self.device),
-                                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                       embedding=flow_embedding.to(self.device))
-                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
-                if cache_speech is not None:
-                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
-                yield {'tts_speech': this_tts_speech}
-            else:
-                assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
-                yield {'tts_speech': cache_speech}
             p.join()
-            torch.cuda.synchronize()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
+                cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
+                this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
+            else:
+                cache_token_len = 0
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token.to(self.device),
+                                                 prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                 prompt_feat=prompt_speech_feat.to(self.device),
+                                                 prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                 embedding=flow_embedding.to(self.device))
+            this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
+            if cache_speech is not None:
+                this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
+            yield {'tts_speech': this_tts_speech}
         else:
-            tts_speech_token = []
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=text_len.to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=prompt_text_len.to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                        embedding=llm_embedding.to(self.device),
-                                        beam_size=1,
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3,
-                                        stream=stream):
-                tts_speech_token.append(i)
-            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
-            tts_speech_token = torch.concat(tts_speech_token, dim=1)
-            tts_mel = self.flow.inference(token=tts_speech_token,
-                                          token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                          prompt_token=flow_prompt_speech_token.to(self.device),
-                                          prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                          prompt_feat=prompt_speech_feat.to(self.device),
-                                          prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                          embedding=flow_embedding.to(self.device))
-            tts_speech = self.hift.inference(mel=tts_mel).cpu()
-            torch.cuda.empty_cache()
-            yield {'tts_speech': tts_speech}
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                 prompt_token=flow_prompt_speech_token.to(self.device),
+                                                 prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                 prompt_feat=prompt_speech_feat.to(self.device),
+                                                 prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                 embedding=flow_embedding.to(self.device))
+            yield {'tts_speech': this_tts_speech}
+        with self.lock:
+            self.tts_speech_token.pop(this_uuid)
+            self.llm_end.pop(this_uuid)
+        torch.cuda.synchronize()

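A back-of-the-envelope check of the new streaming constants, assuming the 50 token/s rate and the 22050/256 mel frame rate used elsewhere in this commit (the constants come from the diff, the derivation itself is an editorial sketch):

```python
token_rate = 50                    # speech tokens per second (token_len / 50 in flow.py)
sample_rate, hop = 22050, 256      # output rate and mel hop implied by 22050 / 256 in flow.py

token_overlap_len = 20             # tokens re-decoded as overlap between streamed chunks
mel_overlap_frames = 34            # the length regulator maps a 20-token head/tail to 34 mel frames
speech_overlap_len = mel_overlap_frames * hop   # 8704 samples, i.e. the 34 * 256 above

print(token_overlap_len / token_rate)           # 0.4 s of tokens
print(speech_overlap_len / sample_rate)         # ~0.39 s of audio, the cross-faded seam length
```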
View File

@@ -105,6 +105,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)

         # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -112,17 +113,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        feat_len = (token_len / 50 * 22050 / 256).int()
-        h, h_lengths = self.length_regulator(h, feat_len)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)

         # get conditions
-        conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
-        if prompt_feat.shape[1] != 0:
-            for i, j in enumerate(prompt_feat_len):
-                conds[i, :j] = prompt_feat[i]
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)

-        mask = (~make_pad_mask(feat_len)).to(h)
+        # mask = (~make_pad_mask(feat_len)).to(h)
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
@@ -130,6 +130,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10
         )
-        if prompt_feat.shape[1] != 0:
-            feat = feat[:, :, prompt_feat.shape[1]:]
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
         return feat

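A small worked example of the new mel-length bookkeeping, using made-up lengths (the 50 token/s and 22050/256 factors are taken from the expression itself):

```python
token_len2 = 200                               # 200 generated speech tokens ≈ 4 s of audio
mel_len2 = int(token_len2 / 50 * 22050 / 256)  # 344 mel frames to generate
mel_len1 = 120                                 # hypothetical prompt mel length (prompt_feat.shape[1])

# conds covers prompt + generated frames; only the generated tail is returned
total_frames = mel_len1 + mel_len2             # 464, and feat[:, :, mel_len1:] keeps the last 344
print(mel_len2, total_frames)
```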
View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple
 import torch.nn as nn
+import torch
 from torch.nn import functional as F
 from cosyvoice.utils.mask import make_pad_mask
@@ -47,3 +48,21 @@ class InterpolateRegulator(nn.Module):
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
+
+    def inference(self, x1, x2, mel_len1, mel_len2):
+        # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
+        # x in (B, T, D)
+        if x2.shape[1] > 40:
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+        else:
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+        if x1.shape[1] != 0:
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x = torch.concat([x1, x2], dim=2)
+        else:
+            x = x2
+        out = self.model(x).transpose(1, 2).contiguous()
+        return out, mel_len1 + mel_len2

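To see what the new `inference()` does to tensor shapes, a standalone sketch of the head/mid/tail interpolation on dummy inputs (the feature dimension 80 and all lengths are invented for illustration; the real method also runs `self.model` on the result):

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 60, 80)    # prompt tokens after the encoder, (B, T, D)
x2 = torch.randn(1, 200, 80)   # generated tokens after the encoder
mel_len1, mel_len2 = 120, 344

# x2 is longer than 40 tokens, so head (20 tokens -> 34 frames), middle and tail
# are interpolated separately, giving a fixed-length seam at both ends
x2_head = F.interpolate(x2[:, :20].transpose(1, 2), size=34, mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2), size=mel_len2 - 34 * 2, mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2), size=34, mode='linear')
x1_up = F.interpolate(x1.transpose(1, 2), size=mel_len1, mode='linear')

x = torch.concat([x1_up, x2_head, x2_mid, x2_tail], dim=2)
print(x.shape)   # torch.Size([1, 80, 464]) == (B, D, mel_len1 + mel_len2)
```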
View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-            sampling: Union[bool, int, float] = True,
-            beam_size: int = 1,
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
@@ -154,12 +156,10 @@ class TransformerLM(torch.nn.Module):
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-            stream: bool = False,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -197,16 +197,11 @@ class TransformerLM(torch.nn.Module):
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
             # in stream mode, yield token one by one
-            if stream is True:
-                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        # in non-stream mode, yield all token
-        if stream is False:
-            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)

View File

@@ -101,3 +101,37 @@ def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
+
+# Repetition Aware Sampling in VALL-E 2
+def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+    if rep_num >= win_size * tau_r:
+        top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
+    return top_ids
+
+def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+    prob, indices = [], []
+    cum_prob = 0.0
+    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+    for i in range(len(sorted_idx)):
+        # sampling both top-p and numbers.
+        if cum_prob < top_p and len(prob) < top_k:
+            cum_prob += sorted_value[i]
+            prob.append(sorted_value[i])
+            indices.append(sorted_idx[i])
+        else:
+            break
+    prob = torch.tensor(prob).to(weighted_scores)
+    indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+    top_ids = indices[prob.multinomial(1, replacement=True)]
+    return top_ids
+
+def random_sampling(weighted_scores, decoded_tokens, sampling):
+    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+    return top_ids
+
+def fade_in_out(fade_in_speech, fade_out_speech, window):
+    speech_overlap_len = int(window.shape[0] / 2)
+    fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:]
+    return fade_in_speech

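A standalone sketch of exercising the new sampling helpers on dummy logits (the vocabulary size and token history below are invented for illustration):

```python
import torch
from cosyvoice.utils.common import ras_sampling

weighted_scores = torch.randn(4096)            # hypothetical logits over speech tokens
decoded_tokens = [11] + [42] * 9               # recent history dominated by one token id

# nucleus sampling proposes a candidate; if it already occurs in more than
# win_size * tau_r of the last win_size tokens, fall back to random sampling
top_ids = ras_sampling(weighted_scores, decoded_tokens, sampling=25,
                       top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
print(top_ids.item())
```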
View File

@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         pos_enc_layer_type: 'rel_pos_espnet'
         selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1

 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512
View File

@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         pos_enc_layer_type: 'rel_pos_espnet'
        selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1

 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512

View File

@@ -61,8 +61,11 @@ def main():
         request.instruct_request.CopyFrom(instruct_request)

     response = stub.Inference(request)
+    tts_audio = b''
+    for r in response:
+        tts_audio += r.tts_audio
+    tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
     logging.info('save response to {}'.format(args.tts_wav))
-    tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
     torchaudio.save(args.tts_wav, tts_speech, target_sr)
     logging.info('get response')

View File

@@ -4,7 +4,7 @@ package cosyvoice;
 option go_package = "protos/";

 service CosyVoice{
-  rpc Inference(Request) returns (Response) {}
+  rpc Inference(Request) returns (stream Response) {}
 }

 message Request{
View File

@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)

         logging.info('send inference response')
-        response = cosyvoice_pb2.Response()
-        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
-        return response
+        for i in model_output:
+            response = cosyvoice_pb2.Response()
+            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+            yield response

 def main():
     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)

View File

@@ -164,7 +164,7 @@ def main():
                             outputs=[audio_output])
     mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)
-    demo.launch(server_port=args.port)
+    demo.launch(server_name='0.0.0.0', server_port=args.port)

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()