Merge remote-tracking branch 'origin/inference_streaming' into inference_streaming

This commit is contained in:
禾息
2024-09-03 11:13:25 +08:00
15 changed files with 754 additions and 6 deletions

@@ -159,7 +159,6 @@ class CosyVoiceModel:
         self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -180,7 +179,7 @@ class CosyVoiceModel:
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -193,7 +192,7 @@ class CosyVoiceModel:
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
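
Taken together, the three hunks above move p.join() out of the spot right after p.start() and into the two places where token generation is actually finished: after the streaming consumption loop, and at the top of the non-streaming branch. The sketch below shows the resulting producer/consumer shape in isolation; it is a simplified illustration, not the repository code: a plain list and a boolean stand in for tts_speech_token_dict / llm_end_dict, and the hop size, token count and sleep values are made up.

    import threading
    import time

    tokens = []            # stands in for tts_speech_token_dict[this_uuid]
    llm_done = False       # stands in for llm_end_dict[this_uuid]

    def llm_job():
        # producer: append tokens as the (simulated) LLM emits them, then raise the flag
        global llm_done
        for i in range(50):
            time.sleep(0.01)
            tokens.append(i)
        llm_done = True

    p = threading.Thread(target=llm_job)
    p.start()              # no join here: joining now would block until every token exists

    token_hop_len = 10
    while True:            # consumer: flush a hop of tokens whenever enough have arrived
        time.sleep(0.05)
        if len(tokens) >= token_hop_len:
            hop = tokens[:token_hop_len]
            del tokens[:token_hop_len]
            print('synthesize chunk from', hop)
        if llm_done and len(tokens) < token_hop_len:
            break
    p.join()               # join only after the consumer is done, as in the hunks above
    print('synthesize final chunk from', tokens)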

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
         mask = (~make_pad_mask(feat_len)).to(h)
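
The added loop body is a form of conditioning dropout for training the flow model: with probability 0.5 a sample gets an all-zero prompt, otherwise a random prefix of at most 30% of its valid frames is copied into conds. Below is a self-contained sketch of just that step; the batch size, feature shapes and lengths are invented for illustration.

    import random
    import torch

    batch, max_len, mel_dim = 2, 100, 80            # toy sizes, not taken from the repo
    feat = torch.randn(batch, max_len, mel_dim)     # padded prompt features
    feat_len = torch.tensor([100, 73])              # valid length of each sample

    conds = torch.zeros(feat.shape)
    for i, j in enumerate(feat_len):
        if random.random() < 0.5:
            continue                                # half the time: train without any prompt conditioning
        index = random.randint(0, int(0.3 * j))     # otherwise keep a random prefix, at most 30% of the valid frames
        conds[i, :index] = feat[i, :index]
    conds = conds.transpose(1, 2)                   # (B, T, D) -> (B, D, T)
    print(conds.shape)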

@@ -82,10 +82,10 @@ class ConditionalCFM(BASECFM):
         sol = []
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
+                cfg_dphi_dt = self.estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
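
This hunk sits inside the Euler solver of the conditional flow-matching decoder: each step calls the estimator once with the real conditions and, when inference_cfg_rate > 0, once with zeroed conditions, then blends the two predictions (classifier-free guidance). The sketch below only shows that step structure with a toy estimator; the linear "estimator", the shapes and the guidance weight are invented, and the (1 + w) * cond - w * uncond blend is the common VoiceBox-style formulation rather than a copy of this file.

    import torch

    def estimator(x, mu):
        # toy stand-in for the real estimator network (which also takes mask, t, spks and cond)
        return mu - x

    x = torch.randn(1, 80, 50)          # noise to be transported toward a mel spectrogram
    mu = torch.randn(1, 80, 50)         # conditioning produced by the upstream model
    t_span = torch.linspace(0, 1, 11)   # integration grid
    cfg_rate = 0.7                      # plays the role of inference_cfg_rate

    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        dphi_dt = estimator(x, mu)
        # classifier-free guidance: also predict with zeroed conditions and blend
        cfg_dphi_dt = estimator(x, torch.zeros_like(mu))
        dphi_dt = (1.0 + cfg_rate) * dphi_dt - cfg_rate * cfg_dphi_dt
        x = x + dt * dphi_dt            # explicit Euler update
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    print(x.shape)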

@@ -299,7 +299,7 @@ class BaseEncoder(torch.nn.Module):
                rate.
             3. Currently, nn.Sequential is used to stack all the convolution
                layers in subsampling, we need to rewrite it to make it work
-               with cache, which is not prefered.
+               with cache, which is not preferred.
         Args:
             xs (torch.Tensor): (1, max_len, dim)
             chunk_size (int): decoding chunk size
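
The docstring edited here describes streaming, chunk-by-chunk encoding, where each chunk is processed together with a cache carried over from earlier chunks instead of re-reading the whole input. The sketch below only illustrates that calling pattern; encode_chunk and its cache handling are placeholders, not the BaseEncoder API.

    import torch

    def encode_chunk(xs_chunk, cache):
        # placeholder for a streaming encoder step that consumes and returns its cache
        new_cache = cache + [xs_chunk]
        return xs_chunk, new_cache      # identity "encoder", just to show the data flow

    xs = torch.randn(1, 64, 80)         # (1, max_len, dim), as in the docstring
    chunk_size = 16                     # decoding chunk size
    cache, outputs = [], []
    for start in range(0, xs.size(1), chunk_size):
        chunk = xs[:, start:start + chunk_size]
        ys, cache = encode_chunk(chunk, cache)   # left context comes from the cache, not from re-reading xs
        outputs.append(ys)
    ys_full = torch.cat(outputs, dim=1)
    print(ys_full.shape)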