zero3 support (#273)

update for lora's modules_to_save
qianyu chen
2024-07-15 10:32:17 +08:00
committed by GitHub
parent ef7cfa81ec
commit e002c0e6ec
3 changed files with 26 additions and 51 deletions
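Note on the LoRA-related deletions below: they drop the hand-rolled export of the vision tower, resampler, and embedding weights, which is only safe once those modules are covered by the adapter's modules_to_save so that PEFT's save_pretrained checkpoints them together with the LoRA weights. A minimal sketch of that mechanism, using a stand-in GPT-2 model and stand-in module names ("c_attn", "wte") rather than the repo's actual finetuning config, which is not part of this diff:

# Sketch only: modules_to_save folds extra fully-trained modules into the same
# checkpoint that save_pretrained writes, so no separate torch.save is needed.
# "gpt2", "c_attn" and "wte" are illustrative stand-ins, not MiniCPM-V modules.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["c_attn"],   # layers that receive LoRA adapters
    modules_to_save=["wte"],     # layers trained in full and saved alongside the adapter
)
model = get_peft_model(base_model, lora_config)
model.save_pretrained("lora_ckpt")  # adapter + modules_to_save weights in one checkpoint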


@@ -15,19 +15,12 @@ class CPMTrainer(Trainer):
         else:
             labels = None
         self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device)
-        if is_deepspeed_zero3_enabled():
-            with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
-                if not self.args.use_lora:
-                    outputs = self.model(data = inputs, use_cache=False)
-                else:
-                    with self.model._enable_peft_forward_hooks(**inputs):
-                        outputs = self.model.base_model(data = inputs, use_cache=False)
+        if not self.args.use_lora:
+            outputs = self.model(data = inputs, use_cache=False)
         else:
-            if not self.args.use_lora:
-                outputs = self.model(data = inputs, use_cache=False)
-            else:
-                with self.model._enable_peft_forward_hooks(**inputs):
-                    outputs = self.model.base_model(data = inputs, use_cache=False)
+            with self.model._enable_peft_forward_hooks(**inputs):
+                outputs = self.model.base_model(data = inputs, use_cache=False)
         if labels is not None:
             # Flatten the tokens
@@ -215,11 +208,7 @@ class CPMTrainer(Trainer):
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            if is_deepspeed_zero3_enabled():
-                with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
-                    self.accelerator.backward(loss)
-            else:
-                self.accelerator.backward(loss)
+            self.accelerator.backward(loss)
         return loss.detach() / self.args.gradient_accumulation_steps
@@ -249,20 +238,10 @@ class CPMTrainer(Trainer):
                 else:
                     torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            if self.args.use_lora:
-                from collections import OrderedDict
-                state_dict_vision = OrderedDict()
-                for key, values in state_dict.items():
-                    if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
-                        state_dict_vision[key] = values
-                self.model.save_pretrained(
-                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-                )
-                torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
-            else:
-                self.model.save_pretrained(
-                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-                )
+            self.model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
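Context on the ZeRO-3 branches removed above: under DeepSpeed ZeRO-3 each parameter is partitioned across ranks, and deepspeed.zero.GatheredParameters temporarily materializes the full tensors inside its with-block. The old trainer wrapped the forward and backward passes in that context for the resampler's attention parameters; this commit drops the wrapper, presumably relying on DeepSpeed's own module hooks to gather parameters during the forward pass. A minimal sketch of the removed pattern, assuming a ZeRO-3 run is active; forward_with_gather is an illustrative helper, not a function from the repo:

# Sketch of the GatheredParameters pattern the commit removes.
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled

def forward_with_gather(model, inputs):
    if is_deepspeed_zero3_enabled():
        # Temporarily gather the partitioned resampler attention weights on this rank.
        with deepspeed.zero.GatheredParameters(model.resampler.attn.parameters()):
            return model(data=inputs, use_cache=False)
    # Outside ZeRO-3 (or after this change), no explicit gathering is done here.
    return model(data=inputs, use_cache=False)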