This commit is contained in:
lyuxiang.lx
2026-01-29 06:13:36 +00:00
parent 66b80dbccb
commit f26cde56df
7 changed files with 90 additions and 73 deletions

View File

@@ -18,14 +18,13 @@ class SpeechTokenExtractor():
sess_options=option,
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
def inference(self, feat, feat_lengths, device):
ort_out = self.speech_tokenizer_session.run(None,
def inference(self, feat, feat_lengths):
speech_token = self.speech_tokenizer_session.run(None,
{self.speech_tokenizer_session.get_inputs()[0].name:
feat.detach().cpu().numpy(),
feat.transpose(1, 2).detach().cpu().numpy(),
self.speech_tokenizer_session.get_inputs()[1].name:
feat_lengths.detach().cpu().numpy()})
speech_token, speech_token_embedding = ort_out[0], ort_out[1]
return torch.tensor(speech_token).to(device), (feat_lengths / 2).to(torch.int32).to(device)
feat_lengths.detach().cpu().numpy()})[0]
return torch.tensor(speech_token).to(feat), (feat_lengths / 4).to(torch.int32).to(feat.device)
class EmbeddingExtractor():