Modify eval_mm for MiniCPM-o 2.6

This commit is contained in:
Poppy Xu
2025-01-21 15:34:54 +08:00
parent ec68cefc17
commit d8f382e157
82 changed files with 14279 additions and 843 deletions

View File

@@ -29,15 +29,15 @@ def chat_mt(model, messages, dataset_name):
try:
resp = model.chat(utter_stack, dataset=dataset_name)
utter_stack.append(dict(role='assistant', content=resp))
except:
resp = FAIL_MSG
except Exception as e:
resp = FAIL_MSG + str(e)
utter_stack.append(dict(role='assistant', content=resp))
predictions.append(resp)
return predictions
# Only API models are accepted
def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
rank, world_size = get_rank_and_world_size()
assert rank == 0 and world_size == 1
dataset_name = dataset.dataset_name
@@ -45,7 +45,7 @@ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, i
if index_set is not None:
data = data[data['index'].isin(index_set)]
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert getattr(model, 'is_api', False)
assert hasattr(model, 'chat_inner')
@@ -74,7 +74,7 @@ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, i
return res
def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
dataset_name = dataset.dataset_name
res = {}
if osp.exists(out_file):
@@ -101,13 +101,14 @@ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc
data = data[~data['index'].isin(res)]
lt = len(data)
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
model = supported_VLM[model_name]() if isinstance(model, str) else model
assert hasattr(model, 'chat_inner')
is_api = getattr(model, 'is_api', False)
if is_api:
lt, indices = len(data), list(data['index'])
supp = infer_data_api(
model=model,
work_dir=work_dir,
model_name=model_name,
dataset=dataset,
@@ -118,7 +119,7 @@ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc
res.update(supp)
res = {k: res[k] for k in data_indices}
dump(res, out_file)
return model_name
return model
else:
model.set_dump_image(dataset.dump_image)
@@ -157,7 +158,8 @@ def infer_data_job_mt(model, work_dir, model_name, dataset, verbose=False, api_n
out_file = tmpl.format(rank)
model = infer_data(
model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
model=model, model_name=model_name,work_dir=work_dir, dataset=dataset,
out_file=out_file, verbose=verbose, api_nproc=api_nproc)
if world_size > 1:
dist.barrier()