update finetune sh

This commit is contained in:
cjm
2024-05-23 10:24:33 +08:00
parent 296a0c5c1e
commit 7868a4de27
4 changed files with 11 additions and 6 deletions

View File

@@ -46,6 +46,7 @@ class TrainingArguments(transformers.TrainingArguments):
)
tune_vision: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
llm_type: str = field(default="minicpm")
def rank0_print(*args):
@@ -166,10 +167,11 @@ def train():
model.llm.requires_grad_(False)
rank0_print(get_parameter_number(model))
llm_type = "minicpm"
if "llama3" in model.name_or_path.lower():
llm_type = training_args.llm_type
if llm_type == "llama3":
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
llm_type = "llama3"
rank0_print(f'llm_type={llm_type}')
# Load data
if hasattr(model.config, "slice_config"):