mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
@@ -15,19 +15,12 @@ class CPMTrainer(Trainer):
|
||||
else:
|
||||
labels = None
|
||||
self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
|
||||
if not self.args.use_lora:
|
||||
outputs = self.model(data = inputs, use_cache=False)
|
||||
else:
|
||||
with self.model._enable_peft_forward_hooks(**inputs):
|
||||
outputs = self.model.base_model(data = inputs, use_cache=False)
|
||||
|
||||
if not self.args.use_lora:
|
||||
outputs = self.model(data = inputs, use_cache=False)
|
||||
else:
|
||||
if not self.args.use_lora:
|
||||
outputs = self.model(data = inputs, use_cache=False)
|
||||
else:
|
||||
with self.model._enable_peft_forward_hooks(**inputs):
|
||||
outputs = self.model.base_model(data = inputs, use_cache=False)
|
||||
with self.model._enable_peft_forward_hooks(**inputs):
|
||||
outputs = self.model.base_model(data = inputs, use_cache=False)
|
||||
|
||||
if labels is not None:
|
||||
# Flatten the tokens
|
||||
@@ -215,11 +208,7 @@ class CPMTrainer(Trainer):
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
|
||||
self.accelerator.backward(loss)
|
||||
else:
|
||||
self.accelerator.backward(loss)
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return loss.detach() / self.args.gradient_accumulation_steps
|
||||
|
||||
@@ -249,20 +238,10 @@ class CPMTrainer(Trainer):
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
if self.args.use_lora:
|
||||
from collections import OrderedDict
|
||||
state_dict_vision = OrderedDict()
|
||||
for key, values in state_dict.items():
|
||||
if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
|
||||
state_dict_vision[key] = values
|
||||
self.model.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
|
||||
self.model.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
Reference in New Issue
Block a user