mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-04 17:39:25 +08:00
add stream code
This commit is contained in:
@@ -158,6 +158,7 @@ class TransformerLM(torch.nn.Module):
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
stream: bool = False,
|
||||
) -> torch.Tensor:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
@@ -199,8 +200,13 @@ class TransformerLM(torch.nn.Module):
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
# in stream mode, yield tokens one by one
|
||||
if stream is True:
|
||||
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
|
||||
out_tokens.append(top_ids)
|
||||
offset += lm_input.size(1)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
||||
# in non-stream mode, yield all tokens at once
|
||||
if stream is False:
|
||||
yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
||||
|
||||
Reference in New Issue
Block a user