mirror of
https://github.com/FunAudioLLM/CosyVoice.git
synced 2026-02-05 18:09:24 +08:00
fix lint
This commit is contained in:
@@ -31,7 +31,6 @@ import torch
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
@@ -96,17 +95,20 @@ if __name__ == "__main__":
|
||||
# set the weight and bias of the new lm_head to 0
|
||||
new_lm_head.weight.data.zero_()
|
||||
new_lm_head.bias.data.zero_()
|
||||
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
|
||||
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
|
||||
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
|
||||
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
|
||||
|
||||
llm.lm_head = new_lm_head
|
||||
input_embeddings = llm.get_input_embeddings()
|
||||
|
||||
with torch.no_grad():
|
||||
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
|
||||
input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
|
||||
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
|
||||
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
|
||||
|
||||
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
|
||||
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
|
||||
llm.generation_config.eos_token_id = eos_token_ids
|
||||
llm.generation_config.temperature = 1.0
|
||||
llm.generation_config.top_p = 0.8
|
||||
@@ -121,4 +123,4 @@ if __name__ == "__main__":
|
||||
|
||||
TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
|
||||
tokenizer.chat_template = TEMPLATE
|
||||
tokenizer.save_pretrained(args.save_path)
|
||||
tokenizer.save_pretrained(args.save_path)
|
||||
|
||||
Reference in New Issue
Block a user