mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
update
This commit is contained in:
@@ -189,7 +189,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Dict[str, Optional[torch.Tensor]]:
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
if 'speech_token' not in batch:
|
if 'speech_token' not in batch:
|
||||||
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'])
|
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
else:
|
else:
|
||||||
token = batch['speech_token'].to(device)
|
token = batch['speech_token'].to(device)
|
||||||
token_len = batch['speech_token_len'].to(device)
|
token_len = batch['speech_token_len'].to(device)
|
||||||
@@ -322,6 +322,9 @@ class CausalMaskedDiffWithDiT(torch.nn.Module):
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Dict[str, Optional[torch.Tensor]]:
|
) -> Dict[str, Optional[torch.Tensor]]:
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
token = batch['speech_token'].to(device)
|
token = batch['speech_token'].to(device)
|
||||||
token_len = batch['speech_token_len'].to(device)
|
token_len = batch['speech_token_len'].to(device)
|
||||||
feat = batch['speech_feat'].to(device)
|
feat = batch['speech_feat'].to(device)
|
||||||
|
|||||||
@@ -367,6 +367,9 @@ class Qwen2LM(TransformerLM):
|
|||||||
"""
|
"""
|
||||||
text_token = batch['text_token'].to(device)
|
text_token = batch['text_token'].to(device)
|
||||||
text_token_len = batch['text_token_len'].to(device)
|
text_token_len = batch['text_token_len'].to(device)
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
speech_token = batch['speech_token'].to(device)
|
speech_token = batch['speech_token'].to(device)
|
||||||
speech_token_len = batch['speech_token_len'].to(device)
|
speech_token_len = batch['speech_token_len'].to(device)
|
||||||
|
|
||||||
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
|
|||||||
"""
|
"""
|
||||||
text_token = batch['text_token'].to(device)
|
text_token = batch['text_token'].to(device)
|
||||||
text_token_len = batch['text_token_len'].to(device)
|
text_token_len = batch['text_token_len'].to(device)
|
||||||
|
if 'speech_token' not in batch:
|
||||||
|
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||||
|
else:
|
||||||
speech_token = batch['speech_token'].to(device)
|
speech_token = batch['speech_token'].to(device)
|
||||||
speech_token_len = batch['speech_token_len'].to(device)
|
speech_token_len = batch['speech_token_len'].to(device)
|
||||||
|
|
||||||
# NOTE should append instruct_token to sequence, not implemented yet
|
# NOTE should append instruct_token to sequence, not implemented yet
|
||||||
instruct_token = batch['instruct_token'].to(device)
|
instruct_token = batch['instruct_token'].to(device)
|
||||||
instruct_token_len = batch['instruct_token_len'].to(device)
|
instruct_token_len = batch['instruct_token_len'].to(device)
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
import onnxruntime
|
import onnxruntime
|
||||||
import torch, random
|
import torch, random
|
||||||
from torch import nn
|
|
||||||
import os
|
import os
|
||||||
import whisper
|
|
||||||
import numpy as np
|
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechTokenExtractor():
|
class SpeechTokenExtractor():
|
||||||
@@ -18,13 +14,13 @@ class SpeechTokenExtractor():
|
|||||||
sess_options=option,
|
sess_options=option,
|
||||||
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
||||||
|
|
||||||
def inference(self, feat, feat_lengths):
|
def inference(self, feat, feat_lengths, device):
|
||||||
speech_token = self.speech_tokenizer_session.run(None,
|
speech_token = self.speech_tokenizer_session.run(None,
|
||||||
{self.speech_tokenizer_session.get_inputs()[0].name:
|
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||||
feat.transpose(1, 2).detach().cpu().numpy(),
|
feat.transpose(1, 2).detach().cpu().numpy(),
|
||||||
self.speech_tokenizer_session.get_inputs()[1].name:
|
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||||
feat_lengths.detach().cpu().numpy()})[0]
|
feat_lengths.detach().cpu().numpy()})[0]
|
||||||
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
|
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingExtractor():
|
class EmbeddingExtractor():
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
|||||||
feat_extractor: !ref <feat_extractor>
|
feat_extractor: !ref <feat_extractor>
|
||||||
num_frames: 960
|
num_frames: 960
|
||||||
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
compute_whisper_fbank: !name:cosyvoice.dataset.processor.compute_whisper_fbank
|
||||||
|
num_frames: 960
|
||||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||||
sample_rate: !ref <sample_rate>
|
sample_rate: !ref <sample_rate>
|
||||||
hop_size: 480
|
hop_size: 480
|
||||||
|
|||||||
Reference in New Issue
Block a user