diff --git a/.gitignore b/.gitignore
index 12b53ef..4a9f23f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,4 +48,5 @@ compile_commands.json
 *.pt
 pretrained_models/*
 *_pb2_grpc.py
-*_pb2.py
\ No newline at end of file
+*_pb2.py
+*.tar
\ No newline at end of file
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index bcade73..d4bea1e 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -18,7 +18,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging


@@ -118,3 +118,35 @@ class CosyVoice:
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
             start_time = time.time()
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('fp16 and jit are not supported on cpu, setting both to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        del configs
\ No newline at end of file
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index 1fcc31f..ce53f04 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -57,15 +57,15 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}

     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=False)
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
@@ -254,3 +254,175 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
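For context, a minimal sketch of how the new `CosyVoice2` wrapper is expected to be driven. The `inference_zero_shot` call and `load_wav` helper come from the existing `CosyVoice` parent class and `cosyvoice.utils.file_utils`; the model directory and the 24 kHz output rate are assumptions based on the CosyVoice2 pretrained package and should be checked against the shipped `cosyvoice.yaml`:

```python
# Illustrative usage only; paths and sample rate are assumptions, not part of this patch.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_onnx=False, fp16=False)
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)  # 16 kHz reference audio
for i, out in enumerate(cosyvoice.inference_zero_shot('Text to synthesize.',
                                                      'Transcript of the prompt audio.',
                                                      prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), out['tts_speech'], 24000)  # rate per model config
```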
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        self.token_min_hop_len = 1 * self.flow.input_frame_rate
+        self.token_max_hop_len = 2 * self.flow.input_frame_rate
+        self.token_right_context = self.flow.encoder.pre_lookahead_layer.pre_lookahead_len
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor == 1, 'stream_scale_factor is fixed to 1 because caching is not yet implemented in flow matching; this constraint will be loosened in the future'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        if self.fp16 is True:
+            self.llm.half()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
+        self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
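The `llm_job` producer and the consumer loop in `tts` below form a simple thread-based pipeline: the LLM appends speech tokens to a shared per-session list, while the main thread waits until a hop plus the fixed right context is available. A self-contained toy (not CosyVoice code) of that pattern:

```python
# Toy producer/consumer sketch mirroring llm_job + the streaming loop in tts.
import threading
import time

tokens, done = [], False

def producer(n=100):
    global done
    for t in range(n):
        tokens.append(t)          # llm_job appends one speech token at a time
        time.sleep(0.001)         # stand-in for autoregressive decoding latency
    done = True

hop, right_context, offset = 25, 3, 0
threading.Thread(target=producer).start()
while True:
    time.sleep(0.01)
    if len(tokens) - offset >= hop + right_context:
        chunk = tokens[:offset + hop + right_context]  # lookahead included, like token2wav input
        offset += hop                                  # only the hop is committed; lookahead is re-decoded
    if done and len(tokens) - offset < hop + right_context:
        break
# leftover tokens are flushed in one final chunk, mirroring finalize=True
```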
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_token=prompt_token.to(self.device),
+                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_feat=prompt_feat.to(self.device),
+                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=embedding.to(self.device),
+                                         finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.encoder.up_layer.stride:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change is only supported in non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
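`token2wav` stitches consecutive audio chunks with a Hamming-window crossfade over the cached tail of the previous chunk. A simplified 1-D sketch of that stitching (it mirrors `cosyvoice.utils.common.fade_in_out`; shapes and the overlap size are reduced/illustrative):

```python
# Toy crossfade: rising half of the window fades the new chunk in, falling half
# fades the cached tail out, over the overlap region.
import numpy as np

def crossfade(new_chunk, cached_tail, window):
    overlap = len(cached_tail)
    out = new_chunk.copy()
    out[:overlap] = new_chunk[:overlap] * window[:overlap] + cached_tail * window[overlap:]
    return out

overlap = 3840                              # plays the role of source_cache_len (illustrative)
window = np.hamming(2 * overlap)            # same construction as self.speech_window
prev_tail = np.random.randn(overlap)        # cached speech tail from the previous chunk
chunk = np.random.randn(24000)
smoothed = crossfade(chunk, prev_tail, window)
```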
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len, token_offset = self.token_min_hop_len, 0
+            self.flow.encoder.static_chunk_size = self.token_min_hop_len
+            self.flow.decoder.estimator.static_chunk_size = self.token_min_hop_len * self.flow.encoder.up_layer.stride
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= token_hop_len + self.token_right_context:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + token_hop_len + self.token_right_context]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < token_hop_len + self.token_right_context:
+                    break
+            p.join()
+            # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            self.flow.encoder.static_chunk_size = 0
+            self.flow.decoder.estimator.static_chunk_size = 0
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=0,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
\ No newline at end of file
diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index b101d72..dfb3c07 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -13,16 +13,84 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock


+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
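`CausalConv1d` replaces symmetric padding with left-only padding of `kernel_size - 1`, which keeps the output length equal to the input length while guaranteeing each output frame depends only on current and past frames. A quick self-contained check of that property:

```python
# Causal conv sanity check: left-padding by (kernel_size - 1) preserves length,
# and appending future frames must not change past outputs.
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 10)                 # (batch, channels, time)
w = torch.randn(4, 4, 3)                  # kernel_size = 3
y = F.conv1d(F.pad(x, (2, 0)), w)         # causal: pad (k - 1, 0) on the left
assert y.shape[-1] == x.shape[-1]

x2 = torch.cat([x, torch.randn(1, 4, 5)], dim=-1)
y2 = F.conv1d(F.pad(x2, (2, 0)), w)
assert torch.allclose(y, y2[..., :10], atol=1e-6)  # past outputs unchanged
```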
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -39,7 +107,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +124,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,14 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +167,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = ResnetBlock1D(
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,10 +191,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()

@@ -175,7 +247,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
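The `add_optional_chunk_mask` + `mask_to_bias` pair replaces the full (non-causal) padding-mask product: with a static chunk size, frame `i` may attend to all frames up to the end of its own chunk, and the boolean mask is then turned into an additive attention bias. An illustrative stand-in for that combination (toy re-implementation, not the WeNet-style helpers themselves):

```python
# Chunk-causal attention mask and its additive-bias form, on toy sizes.
import torch

def chunk_causal_mask(T: int, chunk: int) -> torch.Tensor:
    idx = torch.arange(T)
    # frame i attends to frame j iff j lies before the end of i's chunk
    return idx.unsqueeze(0) < ((idx // chunk) + 1).unsqueeze(1) * chunk

mask = chunk_causal_mask(6, 2)                                   # (T, T) boolean
bias = (1.0 - mask.float()) * torch.finfo(torch.float32).min     # what mask_to_bias computes
# bias is added to attention logits: allowed positions get 0, blocked ones get -inf-like values
```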
@@ -192,7 +266,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -207,7 +283,9 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -218,4 +296,4 @@ class ConditionalDecoder(nn.Module):
         x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
-        return output * mask
+        return output * mask
\ No newline at end of file
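The causal flow module in `flow.py` below trims the encoder's pre-lookahead frames during streaming (`h = h[:, :-pre_lookahead_len * stride]`), because those frames still depend on tokens that have not been generated yet. The arithmetic is easier to see in isolation, with toy numbers (a pre-lookahead of 3 tokens and a 2x upsampling stride are assumptions for illustration):

```python
# Toy arithmetic behind the streaming lookahead trim in CausalMaskedDiffWithXvec.inference.
pre_lookahead_len, stride = 3, 2
tokens_seen = 25                              # prompt + generated tokens fed to the encoder
mel_frames = tokens_seen * stride             # encoder upsamples tokens to mel length
finalize = False
if not finalize:
    mel_frames -= pre_lookahead_len * stride  # drop frames that still depend on future tokens
assert mel_frames == 44
```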
diff --git a/cosyvoice/flow/flow.py b/cosyvoice/flow/flow.py
index eea705b..459e3fc 100644
--- a/cosyvoice/flow/flow.py
+++ b/cosyvoice/flow/flow.py
@@ -146,3 +146,83 @@ class MaskedDiffWithXvec(torch.nn.Module):
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, None
\ No newline at end of file
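The `flow_matching.py` change below replaces two estimator calls per Euler step with one batched call: the conditional and unconditional classifier-free-guidance branches are stacked along the batch dimension and split after the forward pass. A minimal self-contained sketch of the trick (the estimator here is a stand-in function, not the real network):

```python
# Batched classifier-free guidance: one forward pass instead of two.
import torch

def estimator(x, mu):            # stand-in for the flow-matching estimator network
    return x + mu

x, mu, cfg_rate = torch.randn(1, 80, 100), torch.randn(1, 80, 100), 0.7
out = estimator(torch.cat([x, x], dim=0),
                torch.cat([mu, torch.zeros_like(mu)], dim=0))    # zeroed condition = unconditional branch
cond, uncond = torch.split(out, [x.size(0), x.size(0)], dim=0)
guided = (1.0 + cfg_rate) * cond - cfg_rate * uncond
# identical result to two separate calls, but with better GPU utilization
assert torch.allclose(guided, (1.0 + cfg_rate) * estimator(x, mu)
                              - cfg_rate * estimator(x, torch.zeros_like(mu)))
```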
diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py
index d011304..1e59f13 100644
--- a/cosyvoice/flow/flow_matching.py
+++ b/cosyvoice/flow/flow_matching.py
@@ -89,17 +89,25 @@ class ConditionalCFM(BASECFM):

         sol = []
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
-                    x, mask,
-                    torch.zeros_like(mu), t,
-                    torch.zeros_like(spks) if spks is not None else None,
-                    torch.zeros_like(cond)
-                )
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
-                           self.inference_cfg_rate * cfg_dphi_dt)
+                x_in = torch.concat([x, x], dim=0)
+                mask_in = torch.concat([mask, mask], dim=0)
+                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
+                t_in = torch.concat([t, t], dim=0)
+                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
+                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
@@ -163,3 +171,37 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
diff --git a/cosyvoice/tokenizer/tokenizer.py b/cosyvoice/tokenizer/tokenizer.py
index caecf26..3cbe8b5 100644
--- a/cosyvoice/tokenizer/tokenizer.py
+++ b/cosyvoice/tokenizer/tokenizer.py
@@ -2,6 +2,8 @@ import base64
 import os
 from functools import lru_cache
 from typing import Optional
+import torch
+from transformers import AutoTokenizer
 from whisper.tokenizer import Tokenizer

 import tiktoken
@@ -234,3 +236,37 @@ def get_tokenizer(
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
     )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+            ]
+        }
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
\ No newline at end of file
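A hypothetical round trip with the new `QwenTokenizer`; `token_path` must point to a local directory (or hub ID) containing a Qwen tokenizer, and the path shown is illustrative rather than part of this patch:

```python
# Sketch only: assumes the Qwen tokenizer files ship alongside the CosyVoice2 model.
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(token_path='pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN',
                               skip_special_tokens=True)
ids = tokenizer.encode('Hello [laughter] world')  # event tags map to single special-token ids
print(tokenizer.decode(ids))                      # special tokens are dropped on decode
```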
diff --git a/cosyvoice/transformer/encoder_layer.py b/cosyvoice/transformer/encoder_layer.py
index dfd758b..efbb12d 100644
--- a/cosyvoice/transformer/encoder_layer.py
+++ b/cosyvoice/transformer/encoder_layer.py
@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        self.norm1 = nn.LayerNorm(size, eps=1e-12)
+        self.norm2 = nn.LayerNorm(size, eps=1e-12)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = feed_forward
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
             self.norm_final = nn.LayerNorm(
-                size, eps=1e-5)  # for the final output of the block
+                size, eps=1e-12)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
index 25bc835..2e12ad2 100644
--- a/cosyvoice/utils/common.py
+++ b/cosyvoice/utils/common.py
@@ -153,3 +153,14 @@ def set_all_random_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * torch.finfo(dtype).min
+    return mask
\ No newline at end of file
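A quick behavior check for the new `mask_to_bias` helper (runnable once this patch is applied): `True` (keep) positions become `0.0` and `False` (masked) positions become the most negative representable value, so the result can be added directly to attention logits before the softmax.

```python
import torch
from cosyvoice.utils.common import mask_to_bias

mask = torch.tensor([[True, True, False]])
bias = mask_to_bias(mask, torch.float32)
assert bias[0, 0] == 0.0
assert bias[0, 2] == torch.finfo(torch.float32).min
```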