Mirror of https://github.com/OpenBMB/MiniCPM-V.git (last synced 2026-02-05 18:29:18 +08:00)
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
@@ -18,7 +18,7 @@ def parse_args():
|
||||
|
||||
|
||||
# Only API models are accepted
|
||||
def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
|
||||
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
|
||||
rank, world_size = get_rank_and_world_size()
|
||||
assert rank == 0 and world_size == 1
|
||||
dataset_name = dataset.dataset_name
|
||||
@@ -26,11 +26,24 @@ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, i
|
||||
if index_set is not None:
|
||||
data = data[data['index'].isin(index_set)]
|
||||
|
||||
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
|
||||
model = supported_VLM[model_name]() if isinstance(model, str) else model
|
||||
assert getattr(model, 'is_api', False)
|
||||
if hasattr(model, 'set_dump_image'):
|
||||
model.set_dump_image(dataset.dump_image)
|
||||
|
||||
lt, indices = len(data), list(data['index'])
|
||||
structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
|
||||
|
||||
structs = []
|
||||
for i in range(lt):
|
||||
item = data.iloc[i]
|
||||
if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
|
||||
assert hasattr(model, 'build_prompt')
|
||||
struct = model.build_prompt(item, dataset=dataset_name)
|
||||
else:
|
||||
struct = dataset.build_prompt(item)
|
||||
structs.append(struct)
|
||||
|
||||
# structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]
|
||||
|
||||
out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
|
||||
res = {}
|
||||
@@ -55,7 +68,7 @@ def infer_data_api(work_dir, model_name, dataset, index_set=None, api_nproc=4, i
|
||||
return res
|
||||
|
||||
|
||||
def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
|
||||
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
|
||||
dataset_name = dataset.dataset_name
|
||||
prev_file = f'{work_dir}/{model_name}_{dataset_name}_PREV.pkl'
|
||||
res = load(prev_file) if osp.exists(prev_file) else {}
|
||||
@@ -83,12 +96,13 @@ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc
|
||||
data = data[~data['index'].isin(res)]
|
||||
lt = len(data)
|
||||
|
||||
model = supported_VLM[model_name]() if isinstance(model_name, str) else model_name
|
||||
model = supported_VLM[model_name]() if isinstance(model, str) else model
|
||||
|
||||
is_api = getattr(model, 'is_api', False)
|
||||
if is_api:
|
||||
lt, indices = len(data), list(data['index'])
|
||||
supp = infer_data_api(
|
||||
model=model,
|
||||
work_dir=work_dir,
|
||||
model_name=model_name,
|
||||
dataset=dataset,
|
||||
@@ -99,7 +113,7 @@ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc
|
||||
res.update(supp)
|
||||
res = {k: res[k] for k in data_indices}
|
||||
dump(res, out_file)
|
||||
return model_name
|
||||
return model
|
||||
else:
|
||||
model.set_dump_image(dataset.dump_image)
|
||||
|
||||
@@ -120,7 +134,7 @@ def infer_data(model_name, work_dir, dataset, out_file, verbose=False, api_nproc
|
||||
print(response, flush=True)
|
||||
|
||||
res[idx] = response
|
||||
if (i + 1) % 20 == 0:
|
||||
if (i + 1) % 10 == 0:
|
||||
dump(res, out_file)
|
||||
|
||||
res = {k: res[k] for k in data_indices}
|
||||
@@ -149,7 +163,8 @@ def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_npro
|
||||
out_file = tmpl.format(rank)
|
||||
|
||||
model = infer_data(
|
||||
model, work_dir=work_dir, dataset=dataset, out_file=out_file, verbose=verbose, api_nproc=api_nproc)
|
||||
model=model, work_dir=work_dir, model_name=model_name, dataset=dataset,
|
||||
out_file=out_file, verbose=verbose, api_nproc=api_nproc)
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
@@ -168,4 +183,6 @@ def infer_data_job(model, work_dir, model_name, dataset, verbose=False, api_npro
|
||||
dump(data, result_file)
|
||||
for i in range(world_size):
|
||||
os.remove(tmpl.format(i))
|
||||
if world_size > 1:
|
||||
dist.barrier()
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user