mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 09:29:25 +08:00
update
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
import onnxruntime
|
||||
import torch, random
|
||||
from torch import nn
|
||||
import os
|
||||
import whisper
|
||||
import numpy as np
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SpeechTokenExtractor():
|
||||
@@ -18,13 +14,13 @@ class SpeechTokenExtractor():
|
||||
sess_options=option,
|
||||
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
||||
|
||||
def inference(self, feat, feat_lengths):
|
||||
def inference(self, feat, feat_lengths, device):
|
||||
speech_token = self.speech_tokenizer_session.run(None,
|
||||
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||
feat.transpose(1, 2).detach().cpu().numpy(),
|
||||
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||
feat_lengths.detach().cpu().numpy()})[0]
|
||||
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
|
||||
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
|
||||
|
||||
|
||||
class EmbeddingExtractor():
|
||||
|
||||
Reference in New Issue
Block a user