mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
update stream code
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Callable, List, Generator
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
|
||||
speech_token_size: int,
|
||||
text_encoder: torch.nn.Module,
|
||||
llm: torch.nn.Module,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
spk_embed_dim: int = 192,
|
||||
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
|
||||
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
|
||||
def sampling_ids(
|
||||
self,
|
||||
weighted_scores: torch.Tensor,
|
||||
sampling: Union[bool, int, float] = True,
|
||||
beam_size: int = 1,
|
||||
decoded_tokens: List,
|
||||
sampling: int,
|
||||
ignore_eos: bool = True,
|
||||
):
|
||||
while True:
|
||||
prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
|
||||
top_ids = prob.multinomial(beam_size, replacement=True)
|
||||
top_ids = indices[top_ids]
|
||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
||||
break
|
||||
return top_ids
|
||||
@@ -154,12 +156,10 @@ class TransformerLM(torch.nn.Module):
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
beam_size: int = 1,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
stream: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
@@ -197,16 +197,11 @@ class TransformerLM(torch.nn.Module):
|
||||
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
|
||||
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
if stream is True:
|
||||
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
|
||||
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
|
||||
out_tokens.append(top_ids)
|
||||
offset += lm_input.size(1)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
# in non-stream mode, yield all token
|
||||
if stream is False:
|
||||
yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
|
||||
|
||||
Reference in New Issue
Block a user