add stream code

This commit is contained in:
lyuxiang.lx
2024-07-23 00:02:30 +08:00
parent 2895d99b9a
commit a13411c561
4 changed files with 123 additions and 44 deletions

View File

@@ -158,6 +158,7 @@ class TransformerLM(torch.nn.Module):
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
stream: bool = False,
) -> torch.Tensor:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
@@ -199,8 +200,13 @@ class TransformerLM(torch.nn.Module):
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
# in stream mode, yield token one by one
if stream is True:
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
out_tokens.append(top_ids)
offset += lm_input.size(1)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
return torch.tensor([out_tokens], dtype=torch.int64, device=device)
# in non-stream mode, yield all token
if stream is False:
yield torch.tensor([out_tokens], dtype=torch.int64, device=device)