update lora finetune inference bug (#224)

This commit is contained in:
qianyu chen
2024-06-07 18:00:22 +08:00
committed by GitHub
parent 31eaa26ee1
commit 9bd93a281c
6 changed files with 61 additions and 17 deletions

View File

@@ -51,7 +51,6 @@ class TrainingArguments(transformers.TrainingArguments):
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
scale_resolution: Optional[int] = field(default=448)
@dataclass
@@ -270,17 +269,15 @@ def train():
)
model = get_peft_model(model, lora_config)
model.base_model.resampler.requires_grad_(True)
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
if training_args.tune_vision:
model.base_model.vpm.requires_grad_(True)
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
rank0_print(get_parameter_number(model))
llm_type = training_args.llm_type
if llm_type == "llama3":
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
rank0_print(f'llm_type={llm_type}')
@@ -288,11 +285,9 @@ def train():
# Load data
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
model.config.slice_config.scale_resolution = training_args.scale_resolution
slice_config = model.config.slice_config.to_dict()
else:
model.config.max_slice_nums = training_args.max_slice_nums
model.config.scale_resolution = training_args.scale_resolution
slice_config = model.config.to_dict()
if hasattr(model.config, "batch_vision_input"):