Update zero3 code and OOM FQAs (#188)

This commit is contained in:
qianyu chen
2024-05-31 12:44:21 +08:00
committed by GitHub
parent fe7184f8c9
commit 8d9c86a919
5 changed files with 131 additions and 7 deletions

View File

@@ -50,6 +50,8 @@ class TrainingArguments(transformers.TrainingArguments):
tune_llm: Optional[bool] = field(default=True)
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
scale_resolution: Optional[int] = field(default=448)
@dataclass
@@ -272,11 +274,17 @@ def train():
rank0_print(f'llm_type={llm_type}')
# Load data
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
model.config.slice_config.scale_resolution = training_args.scale_resolution
slice_config = model.config.slice_config.to_dict()
else:
model.config.max_slice_nums = training_args.max_slice_nums
model.config.scale_resolution = training_args.scale_resolution
slice_config = model.config.to_dict()
if hasattr(model.config, "batch_vision_input"):
batch_vision = model.config.batch_vision_input
else: