This commit is contained in:
lyuxiang.lx
2025-08-21 20:08:08 +08:00
parent 8c96081f94
commit 70991d7327
2 changed files with 128 additions and 4 deletions

View File

@@ -609,3 +609,79 @@ class Qwen2LM(TransformerLM):
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
class CosyVoice3LM(Qwen2LM):
def __init__(
self,
llm_input_size: int,
llm_output_size: int,
speech_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
):
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos = 0
self.eos = 1
self.task_id = 2
self.fill_token = 3
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 200,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token