update stream code

Author: lyuxiang.lx
Date:   2024-07-30 16:11:28 +08:00
Parent: 02f941d348
Commit: f4e70e222c

15 changed files with 182 additions and 109 deletions


@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
 
+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
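
Note: `sampling` is now injected into the constructor as a callable instead of being driven by scalar flags. As a minimal sketch, the removed inline top-k logic (see the `sampling_ids` hunk below) can be repackaged as a standalone function matching the new three-argument contract; the function name here is hypothetical, not the sampler shipped with this commit.

    import torch

    def topk_sampling(weighted_scores: torch.Tensor,
                      decoded_tokens: list,
                      sampling: int) -> torch.Tensor:
        # weighted_scores: 1-D scores over the speech-token vocabulary
        # decoded_tokens: ids emitted so far (unused by plain top-k, but part
        # of the contract so samplers can inspect generation history)
        # sampling: top-k size (the inference call site defaults it to 25)
        prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
        top_ids = indices[prob.multinomial(1, replacement=True)]
        return top_ids
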
@@ -132,14 +136,12 @@
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-            sampling: Union[bool, int, float] = True,
-            beam_size: int = 1,
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
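
Note: passing `decoded_tokens` through `sampling_ids` lets the injected sampler condition on what has already been generated, e.g. to discourage repetition. A hypothetical illustration building on the `topk_sampling` sketch above (not the sampler this repo actually configures):

    def repetition_penalized_sampling(weighted_scores: torch.Tensor,
                                      decoded_tokens: list,
                                      sampling: int,
                                      window: int = 10,
                                      penalty: float = 2.0) -> torch.Tensor:
        # subtract a penalty in log space from recently emitted ids,
        # making them less likely to be drawn again
        scores = weighted_scores.clone()
        recent = list(set(decoded_tokens[-window:]))
        if recent:
            scores[recent] -= penalty
        return topk_sampling(scores, decoded_tokens, sampling)
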
@@ -154,12 +156,10 @@
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-            stream: bool = False,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
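
Note: with `stream` removed and the return annotation changed to `Generator[torch.Tensor, None, None]`, `inference` is always consumed as a generator. A minimal usage sketch (the `model` handle and input tensors are placeholders for whatever the caller has prepared):

    out_tokens = []
    for token in model.inference(text, text_len, prompt_text, prompt_text_len,
                                 prompt_speech_token, prompt_speech_token_len,
                                 embedding):
        out_tokens.append(token)  # each yield is a (1, 1) int64 tensor
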
@@ -197,16 +197,11 @@
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
-            # in stream mode, yield token one by one
-            if stream is True:
-                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
-        # in non-stream mode, yield all token
-        if stream is False:
-            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
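
Note: every token is now yielded as a (1, 1) tensor as it is decoded, so the old non-stream output, a single (1, T) tensor of all generated tokens, is reassembled on the caller side. Continuing the consumption sketch above:

    # equivalent to the removed `stream is False` branch
    speech_tokens = torch.cat(out_tokens, dim=1)  # shape (1, T)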