mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
Merge remote-tracking branch 'origin/inference_streaming' into inference_streaming
This commit is contained in:
@@ -159,7 +159,6 @@ class CosyVoiceModel:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
p.start()
|
||||
p.join()
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
@@ -180,7 +179,7 @@ class CosyVoiceModel:
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
# p.join()
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
with self.flow_hift_context:
|
||||
@@ -193,7 +192,7 @@ class CosyVoiceModel:
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
# p.join()
|
||||
p.join()
|
||||
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
|
||||
with self.flow_hift_context:
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
|
||||
@@ -82,10 +82,10 @@ class ConditionalCFM(BASECFM):
|
||||
sol = []
|
||||
|
||||
for step in range(1, len(t_span)):
|
||||
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
|
||||
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
if self.inference_cfg_rate > 0:
|
||||
cfg_dphi_dt = self.forward_estimator(
|
||||
cfg_dphi_dt = self.estimator(
|
||||
x, mask,
|
||||
torch.zeros_like(mu), t,
|
||||
torch.zeros_like(spks) if spks is not None else None,
|
||||
|
||||
@@ -299,7 +299,7 @@ class BaseEncoder(torch.nn.Module):
|
||||
rate.
|
||||
3. Currently, nn.Sequential is used to stack all the convolution
|
||||
layers in subsampling, we need to rewrite it to make it work
|
||||
with cache, which is not prefered.
|
||||
with cache, which is not preferred.
|
||||
Args:
|
||||
xs (torch.Tensor): (1, max_len, dim)
|
||||
chunk_size (int): decoding chunk size
|
||||
|
||||
Reference in New Issue
Block a user