Files
MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/base.py
2025-01-21 15:34:54 +08:00

199 lines
7.6 KiB
Python

from ..smp import *
from ..dataset import img_root_map, DATASET_TYPE
from abc import abstractmethod
class BaseModel:
    """Common base for model wrappers: input normalisation plus generate/chat entry points.

    Subclasses are expected to implement `generate_inner` (and `chat_inner` if
    multi-turn chat is used). Note that this class does not use `ABCMeta`, so the
    `@abstractmethod` markers below are documentation only and do not block
    instantiation.
    """

    # When False, `message_to_promptimg` may be used to flatten a message into
    # a single prompt string and a single image.
    INTERLEAVE = False
    # Content types accepted by `generate`.
    allowed_types = ['text', 'image', 'video']

    def __init__(self):
        # Callable installed through `set_dump_image`; None until then.
        self.dump_image_func = None

    def use_custom_prompt(self, dataset):
        """Whether to use custom prompt for the given dataset.

        Args:
            dataset (str): The name of the dataset.

        Returns:
            bool: Whether to use custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
                Default to False.
        """
        return False

    @abstractmethod
    def build_prompt(self, line, dataset):
        """Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            str: The built message.
        """
        raise NotImplementedError

    def set_dump_image(self, dump_image_func):
        """Install the callable that `dump_image` delegates to."""
        self.dump_image_func = dump_image_func

    def dump_image(self, line, dataset):
        """Return `self.dump_image_func(line)`.

        `dataset` is accepted for interface compatibility but is not used here.
        `set_dump_image` must have been called first; otherwise the stored
        function is still None and the call fails.
        """
        return self.dump_image_func(line)

    @abstractmethod
    def generate_inner(self, message, dataset=None):
        """Model-specific generation on a preprocessed list of dicts; subclasses override."""
        raise NotImplementedError

    def check_content(self, msgs):
        """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.

        Returns:
            str: One of 'str', 'dict', 'liststr', 'listdict', or 'unknown' for
                anything else (including lists with mixed element kinds).
                An empty list is reported as 'liststr'.
        """
        if isinstance(msgs, str):
            return 'str'
        if isinstance(msgs, dict):
            return 'dict'
        if isinstance(msgs, list):
            types = [self.check_content(m) for m in msgs]
            if all(t == 'str' for t in types):
                return 'liststr'
            if all(t == 'dict' for t in types):
                return 'listdict'
        return 'unknown'

    def preproc_content(self, inputs):
        """Convert the raw input messages to a list of dicts.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.

        Raises:
            AssertionError: If a dict item lacks 'type'/'value', or if the type
                declared on a dict item disagrees with what `parse_file` reports.
        """
        # Classify once; `check_content` walks the whole input recursively.
        kind = self.check_content(inputs)

        if kind == 'str':
            return [dict(type='text', value=inputs)]

        if kind == 'dict':
            assert 'type' in inputs and 'value' in inputs
            return [inputs]

        if kind == 'liststr':
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == 'unknown':
                    # Not recognised as a typed file: keep the raw string as text.
                    res.append(dict(type='text', value=s))
                else:
                    # The major part of the MIME string becomes the item type.
                    res.append(dict(type=mime.split('/')[0], value=pth))
            return res

        if kind == 'listdict':
            # Items are validated and updated in place; the same list is returned.
            for item in inputs:
                assert 'type' in item and 'value' in item
                mime, s = parse_file(item['value'])
                if mime is None:
                    assert item['type'] == 'text'
                else:
                    assert mime.split('/')[0] == item['type']
                    item['value'] = s
            return inputs

        return None

    def generate(self, message, dataset=None):
        """Generate the output message.

        Args:
            message (list[dict]): The input message.
            dataset (str, optional): The name of the dataset. Defaults to None.

        Returns:
            str: The generated message.
        """
        assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
        message = self.preproc_content(message)
        assert message is not None and self.check_content(message) == 'listdict'
        for item in message:
            assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
        return self.generate_inner(message, dataset)

    def chat(self, messages, dataset=None):
        """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages.

        Each message's 'content' is replaced in place by its preprocessed form.
        If `chat_inner` raises, the exception is logged, the oldest turn is
        dropped (plus any following turns until one with role 'user' is first),
        and the call is retried on the shorter history.

        Returns:
            The result of `chat_inner`, or a fixed failure string once no turns remain.
        """
        assert hasattr(self, 'chat_inner'), 'The API model should has the `chat_inner` method. '
        for msg in messages:
            assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
            assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
            msg['content'] = self.preproc_content(msg['content'])

        while len(messages):
            try:
                return self.chat_inner(messages, dataset=dataset)
            except Exception as e:
                logging.info(f'{type(e)}: {e}')
                # Slicing builds a new list, so the caller's list keeps its length.
                messages = messages[1:]
                while len(messages) and messages[0]['role'] != 'user':
                    messages = messages[1:]
        return 'Chat Mode: Failed with all possible conversation turns.'

    def message_to_promptimg(self, message, dataset=None):
        """Flatten a message into `(prompt, image)` for models without interleaved input.

        Text values are joined with newlines. The image is None when the message
        has no image, the result of `concat_images_vlmeval` over all images when
        `dataset` equals 'BLINK', and the first image otherwise. A warning is
        issued on every call.
        """
        assert not self.INTERLEAVE
        model_name = self.__class__.__name__
        warnings.warn(
            f'Model {model_name} does not support interleaved input. '
            'Will use the first image and aggregated texts as prompt. ')

        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        images = [x['value'] for x in message if x['type'] == 'image']
        if not images:
            image = None
        elif 'BLINK' == dataset:
            image = concat_images_vlmeval(images, target_size=512)
        else:
            image = images[0]
        return prompt, image

    def _require_video_support(self):
        """Log and raise NotImplementedError unless the model sets a truthy `VIDEO_LLM`."""
        # `VIDEO_LLM` is not declared on this base class, so a subclass may not
        # define it at all; treat a missing flag the same as False.
        if not getattr(self, 'VIDEO_LLM', False):
            logging.critical('Model does not support video input.')
            raise NotImplementedError

    def message_to_promptvideo(self, message):
        """Flatten a message into `(prompt, video)`.

        Text values are joined with newlines; the video is the first 'video'
        value, or None when the message contains no video.

        Raises:
            NotImplementedError: If the model does not set a truthy `VIDEO_LLM`.
        """
        self._require_video_support()
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        videos = [x['value'] for x in message if x['type'] == 'video']
        video = videos[0] if videos else None
        return prompt, video

    def message_to_promptvideo_withrole(self, message, dataset=None):
        """Split a message into role-keyed text plus a single video.

        Text items are concatenated (no separator) into 'system', 'assistant' or
        'user' according to their optional 'role' key; items without a role, or
        with any other role, count as 'user'. When there is no assistant text,
        the 'assistant' entry is set to 'Best Option: (' if the dataset type
        contains 'MCQ', and removed otherwise.

        Returns:
            tuple: `(question, video)` where `question` is a dict and `video` is
                the first video value, or None when the message has no video.

        Raises:
            NotImplementedError: If the model does not set a truthy `VIDEO_LLM`.
        """
        self._require_video_support()

        system, user, assistant, video_list = '', '', '', []
        for msg in message:
            if msg['type'] == 'text':
                if 'role' in msg and msg['role'] == 'system':
                    system += msg['value']
                elif 'role' in msg and msg['role'] == 'assistant':
                    assistant += msg['value']
                else:
                    user += msg['value']
            elif msg['type'] == 'video':
                video_list.append(msg['value'])

        question = {
            'system': system,
            'user': user,
            'assistant': assistant
        }
        if assistant == '':
            if listinstr(['MCQ'], DATASET_TYPE(dataset)):
                question['assistant'] = 'Best Option: ('
            else:
                del question['assistant']

        if len(video_list) > 1:
            print('VLMEvalKit only support single video as input, take first video as input')
        # Match `message_to_promptvideo`: no video yields None instead of IndexError.
        video = video_list[0] if video_list else None
        return question, video