mirror of https://github.com/FunAudioLLM/CosyVoice.git (synced 2026-02-04 09:29:25 +08:00)

Commit message:
add cosyvoice2
.gitignore (vendored)
@@ -48,4 +48,5 @@ compile_commands.json
 *.pt
 pretrained_models/*
 *_pb2_grpc.py
 *_pb2.py
+*.tar
@@ -18,7 +18,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging


@@ -118,3 +118,35 @@ class CosyVoice:
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('cpu do not support fp16 and jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        del configs
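For orientation, a minimal usage sketch of the new wrapper class. The model directory, prompt file name, and the inherited inference_zero_shot signature are assumptions based on the surrounding repo, not part of this hunk:

from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

# Builds the frontend and CosyVoice2Model from <model_dir>/cosyvoice.yaml as shown above.
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, fp16=False)

# 16 kHz prompt audio for zero-shot voice cloning (file path is illustrative).
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for out in cosyvoice.inference_zero_shot('Hello, nice to meet you.', 'This is the prompt transcript.', prompt_speech_16k, stream=False):
    chunk = out['tts_speech']   # (1, num_samples) waveform tensor yielded by CosyVoice2Model.tts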
@@ -57,15 +57,15 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=False)
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
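The switch to strict=True makes mismatches between a checkpoint and the model fail loudly instead of being silently ignored. A small illustration of the difference with a toy module (not CosyVoice code):

import torch

net = torch.nn.Linear(4, 4)
state = {'weight': torch.zeros(4, 4)}           # 'bias' is missing from the state dict
net.load_state_dict(state, strict=False)        # loads what it can, reports missing/unexpected keys
try:
    net.load_state_dict(state, strict=True)     # raises RuntimeError: Missing key(s) "bias"
except RuntimeError as e:
    print(e)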
@@ -254,3 +254,175 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        self.token_min_hop_len = 1 * self.flow.input_frame_rate
+        self.token_max_hop_len = 2 * self.flow.input_frame_rate
+        self.token_right_context = self.flow.encoder.pre_lookahead_layer.pre_lookahead_len
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor == 1, 'fix stream_scale_factor to 1 as we haven\'t implement cache in flow matching yet, this constraint will be loosen in the future'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        if self.fp16 is True:
+            self.llm.half()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
+        self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_token=prompt_token.to(self.device),
+                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_feat=prompt_feat.to(self.device),
+                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=embedding.to(self.device),
+                                         finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.encoder.up_layer.stride:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len, token_offset = self.token_min_hop_len, 0
+            self.flow.encoder.static_chunk_size = self.token_min_hop_len
+            self.flow.decoder.estimator.static_chunk_size = self.token_min_hop_len * self.flow.encoder.up_layer.stride
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= token_hop_len + self.token_right_context:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + token_hop_len + self.token_right_context]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < token_hop_len + self.token_right_context:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            self.flow.encoder.static_chunk_size = 0
+            self.flow.decoder.estimator.static_chunk_size = 0
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=0,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
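The streaming branch above waits until at least token_hop_len + token_right_context unconsumed speech tokens are available, synthesizes that window, then advances token_offset by token_hop_len so the lookahead (right context) frames are re-synthesized with the next chunk. A toy trace of that bookkeeping, assuming input_frame_rate = 25 and pre_lookahead_len = 3 (illustrative numbers, not taken from this commit):

token_min_hop_len = 1 * 25       # tokens consumed per hop
token_right_context = 3          # lookahead tokens needed before a chunk can be decoded
token_offset = 0
available = 60                   # tokens produced by the LLM thread so far

while available - token_offset >= token_min_hop_len + token_right_context:
    window_end = token_offset + token_min_hop_len + token_right_context
    print(f'decode tokens [0, {window_end}), emit mel frames after offset {token_offset}')
    token_offset += token_min_hop_len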
@@ -13,16 +13,84 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock


+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
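CausalConv1d pads only on the left by kernel_size - 1, so the output keeps the input length and frame t never depends on frames after t. A quick check of both properties, using the CausalConv1d class defined in the hunk above (random tensors, illustrative only):

import torch

conv = CausalConv1d(2, 2, kernel_size=3)
x = torch.randn(1, 2, 10)
y1 = conv(x)

x2 = x.clone()
x2[:, :, 5:] = 0.0                                  # change only the "future" part of the signal
y2 = conv(x2)

print(y1.shape)                                     # torch.Size([1, 2, 10]) -> length preserved
print(torch.allclose(y1[:, :, :5], y2[:, :, :5]))   # True -> outputs before t=5 are unaffected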
@@ -39,7 +107,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +124,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,14 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +167,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = ResnetBlock1D(
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,10 +191,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
@@ -175,7 +247,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -192,7 +266,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -207,7 +283,9 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
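In all three attention sites above (down, mid, and up paths) the dense padding mask built with torch.matmul is replaced by a chunked attention mask plus an additive bias. With static_chunk_size > 0, the wenet-style add_optional_chunk_mask helper produces a block-causal pattern: a frame can attend within its own chunk and to every earlier chunk, but not ahead; mask_to_bias (added to the utils later in this same commit) then converts the boolean mask into 0 / most-negative entries, presumably so it can be added directly to the attention logits. A hand-built toy version of that conversion for four frames with chunk size 2 (illustrative; real masks come from add_optional_chunk_mask):

import torch

# block-causal chunk mask, chunk size 2: frames 0-1 form chunk 0, frames 2-3 form chunk 1
chunk_mask = torch.tensor([[1, 1, 0, 0],
                           [1, 1, 0, 0],
                           [1, 1, 1, 1],
                           [1, 1, 1, 1]], dtype=torch.bool)
bias = (1.0 - chunk_mask.float()) * torch.finfo(torch.float32).min
print(bias)  # 0 where attention is allowed, ~-3.4e38 where it is masked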
@@ -218,4 +296,4 @@ class ConditionalDecoder(nn.Module):
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
         return output * mask
@@ -146,3 +146,83 @@ class MaskedDiffWithXvec(torch.nn.Module):
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, None
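The main structural difference from MaskedDiffWithXvec.inference is the streaming lookahead handling: when finalize is False, the last pre_lookahead_len * up_layer.stride encoder frames are dropped, presumably because they were computed with incomplete right context and will be re-synthesized with the next chunk; no flow cache is returned (feat, None). A toy frame-count sketch with illustrative numbers:

pre_lookahead_len = 3
upsample_stride = 2
encoder_frames = 50
usable = encoder_frames - pre_lookahead_len * upsample_stride
print(usable)   # 44 encoder frames are kept for this non-final chunk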
@@ -89,17 +89,25 @@ class ConditionalCFM(BASECFM):
         sol = []

         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
-                    x, mask,
-                    torch.zeros_like(mu), t,
-                    torch.zeros_like(spks) if spks is not None else None,
-                    torch.zeros_like(cond)
-                )
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
-                           self.inference_cfg_rate * cfg_dphi_dt)
+                x_in = torch.concat([x, x], dim=0)
+                mask_in = torch.concat([mask, mask], dim=0)
+                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
+                t_in = torch.concat([t, t], dim=0)
+                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
+                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
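The rewritten loop runs the conditional and unconditional (classifier-free guidance) branches in a single batched estimator call and then splits the result, instead of calling the estimator twice per Euler step. A toy check that batching and splitting matches two separate calls, with a stand-in linear layer playing the role of the estimator (illustrative only):

import torch

est = torch.nn.Linear(8, 8)                 # stands in for the flow-matching estimator
x, mu = torch.randn(1, 8), torch.randn(1, 8)
cfg_rate = 0.7

# two calls (old style)
d_cond = est(x + mu)
d_uncond = est(x + torch.zeros_like(mu))
old = (1.0 + cfg_rate) * d_cond - cfg_rate * d_uncond

# one batched call, then split (new style)
x_in = torch.concat([x, x], dim=0)
mu_in = torch.concat([mu, torch.zeros_like(mu)], dim=0)
d = est(x_in + mu_in)
d_cond2, d_uncond2 = torch.split(d, [1, 1], dim=0)
new = (1.0 + cfg_rate) * d_cond2 - cfg_rate * d_uncond2

print(torch.allclose(old, new, atol=1e-6))  # True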
@@ -163,3 +171,37 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        z[:] = 0
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
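One detail of the solver setup reused here from the parent class: with t_scheduler == 'cosine' the uniform time grid is warped by 1 - cos(t * pi / 2), which clusters the Euler steps near the start of the trajectory. A quick way to see the warped grid for n_timesteps = 10 (plain PyTorch, illustrative only):

import torch

t_span = torch.linspace(0, 1, 10 + 1)
print(1 - torch.cos(t_span * 0.5 * torch.pi))
# tensor([0.0000, 0.0123, 0.0489, 0.1090, 0.1910, 0.2929, 0.4122, 0.5460, 0.6910, 0.8436, 1.0000])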
@@ -2,6 +2,8 @@ import base64
 import os
 from functools import lru_cache
 from typing import Optional
+import torch
+from transformers import AutoTokenizer
 from whisper.tokenizer import Tokenizer

 import tiktoken
@@ -234,3 +236,37 @@ def get_tokenizer(
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
     )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+            ]
+        }
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
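A minimal round-trip sketch for the new wrapper, assuming a local Hugging Face tokenizer directory is available (the path is illustrative, not part of this commit). The event tags listed in additional_special_tokens map to single added tokens, and decode() drops them when skip_special_tokens=True:

tokenizer = get_qwen_tokenizer(token_path='pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN',
                               skip_special_tokens=True)
ids = tokenizer.encode('Hello [laughter] world')
print(ids)                  # list of int token ids, '[laughter]' encoded as one special token
print(tokenizer.decode(ids))  # special tokens removed because skip_special_tokens=True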
@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        self.norm1 = nn.LayerNorm(size, eps=1e-12)
+        self.norm2 = nn.LayerNorm(size, eps=1e-12)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = feed_forward
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
             self.norm_final = nn.LayerNorm(
-                size, eps=1e-5)  # for the final output of the block
+                size, eps=1e-12)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
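The only change in these two encoder-layer hunks is the LayerNorm epsilon, 1e-5 -> 1e-12. The effect is easiest to see on inputs with tiny variance, where a larger eps noticeably damps the normalized output (toy numbers, default affine weight/bias):

import torch

x = torch.tensor([[1.000, 1.001, 0.999]])
ln_old = torch.nn.LayerNorm(3, eps=1e-5)
ln_new = torch.nn.LayerNorm(3, eps=1e-12)
print(ln_old(x))   # roughly [0, 0.31, -0.31]: eps dominates the tiny variance and shrinks the output
print(ln_new(x))   # roughly [0, 1.22, -1.22]: essentially exact normalization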
@@ -153,3 +153,14 @@ def set_all_random_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * torch.finfo(dtype).min
+    return mask
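A quick illustration of the new helper defined above: True (keep) positions become 0.0 and False (masked) positions become the most negative representable value, so the result can simply be added to attention logits before the softmax:

import torch

mask = torch.tensor([[True, True, False]])
print(mask_to_bias(mask, torch.float32))
# -> 0.0 for the kept positions, -3.4028e+38 (float32 min) for the masked one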