Mirror of https://github.com/OpenBMB/MiniCPM-V.git, synced 2026-02-05 10:19:18 +08:00
First commit
This commit is contained in:
omnilmm/train/train_utils.py (new file, 153 lines added)
@@ -0,0 +1,153 @@
import os
import gc
import copy
import time

import torch
import warnings
import transformers

import numpy as np

from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

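# --- Illustrative usage sketch (not part of the original commit) ---
# With any Hugging Face tokenizer that has a pad token set, e.g.:
#
#   tok = transformers.AutoTokenizer.from_pretrained('gpt2')
#   tok.pad_token = tok.eos_token
#   out = _tokenize_fn(['Hello world.', 'A longer example sentence.'], tok)
#   out['input_ids'][0]       # 1-D tensor of token ids for the first string
#   out['input_ids_lens']     # per-string token counts, excluding padding
#
# Note that `input_ids` and `labels` are the same list object; callers that
# modify labels must copy them first, as `omni_preprocess` does below.

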
def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
    ignore_index = -100

    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        new_source = []
        prev_role = 'unexpect'
        for conv_turn in sources[i]:
            # Normalize ShareGPT-style keys ('from'/'value') to chat-style
            # keys ('role'/'content').
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            # Roles must strictly alternate between user and assistant.
            assert role in ['user', 'assistant']
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_turn = {
                'role': role,
                'content': content
            }
            new_source.append(new_turn)
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # TODO: this automatically adds '\n' to the end
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # labels and input_ids are references to the same object, so copy
        # the labels before masking them.
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

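        # Illustrative note (not part of the original commit): with a chat
        # template that matches the markers above, `res_text` looks roughly
        # like
        #   <|system|>\n{system}\n<|user|>\n{question}\n<|assistant|>\n{answer}
        # The exact layout depends on the tokenizer's chat template; only the
        # '\n<|user|>\n' and '\n<|assistant|>\n' markers are assumed below.
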
        response_token_ids_idxs = []
        human_token_ids_idxs = []

        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            # find the indexes of the start of a response.
            if (response_token_ids ==
                    res_labels[assistant_idx: assistant_idx + len(response_token_ids)].tolist()):
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            # find the indexes of the start of a user turn.
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
            res_labels[:] = ignore_index

        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            # Make the pytorch loss function ignore all non-response tokens.
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        # If the conversation ends with a user turn, mask everything after it.
        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
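

# --- Illustrative usage sketch (not part of the original commit) ---
# `sources` may use ShareGPT keys ('from'/'value') or chat keys
# ('role'/'content'). The tokenizer's chat template must emit the
# '\n<|user|>\n' and '\n<|assistant|>\n' markers assumed above; the
# checkpoint below is only a placeholder for a tokenizer with such a template.
if __name__ == '__main__':
    tok = transformers.AutoTokenizer.from_pretrained(
        'HuggingFaceH4/zephyr-7b-beta')  # assumed to use <|user|>/<|assistant|> markers
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.model_max_length = 2048

    sources = [[
        {'from': 'human', 'value': 'What is in the image?'},
        {'from': 'gpt', 'value': 'A cat sitting on a mat.'},
    ]]
    batch = omni_preprocess(sources, tok)
    # labels equal input_ids on assistant-response tokens and -100 elsewhere,
    # so the cross-entropy loss only trains on the assistant's replies.
    print(batch['input_ids'][0])
    print(batch['labels'][0])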