# Mirror of https://github.com/OpenBMB/MiniCPM-V.git
import torch
import torch.distributed as dist

from vlmeval.config import supported_VLM
from vlmeval.utils import track_progress_rich
# The star import re-exports the helpers used below without explicit imports:
# os, osp, argparse, tqdm, load, dump, get_rank_and_world_size, ...
from vlmeval.smp import *

FAIL_MSG = 'Failed to obtain answer via API.'


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    parser.add_argument('--nproc', type=int, default=4)
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()
    return args
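
# Example invocation (the script filename is illustrative; 'MMDU' and 'GPT4o'
# are VLMEvalKit dataset / model identifiers used here only as examples):
#   python mt_infer.py --data MMDU --model GPT4o --nproc 8 --verbose
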
def chat_mt(model, messages, dataset_name):
    # `messages` alternates user / reference-assistant turns, so its length must
    # be even. Only the user turns are replayed; the model's own responses (not
    # the reference answers) are appended to the running dialogue stack.
    assert len(messages) % 2 == 0
    nturn = len(messages) // 2
    utter_stack = []
    predictions = []

    for i in range(nturn):
        utter = messages[2 * i]
        utter_stack.append(utter)
        try:
            resp = model.chat(utter_stack, dataset=dataset_name)
        except Exception as e:
            # Record the failure in place of an answer so later turns and the
            # resume logic can recognize it.
            resp = FAIL_MSG + str(e)
        utter_stack.append(dict(role='assistant', content=resp))
        predictions.append(resp)
    return predictions
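
# A sketch of the layout chat_mt expects, inferred from the invariant above and
# VLMEvalKit's message convention (the concrete dict format is an assumption):
#   messages = [
#       dict(role='user', content=[dict(type='image', value='demo.jpg'),
#                                  dict(type='text', value='Describe the image.')]),
#       dict(role='assistant', content='<reference answer, never sent to the model>'),
#       dict(role='user', content=[dict(type='text', value='Now summarize your answer.')]),
#       dict(role='assistant', content='<reference answer, never sent to the model>'),
#   ]
#   predictions = chat_mt(model, messages, 'MMDU')  # one entry per user turn
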
# Only API models are accepted.
def infer_data_api(model, work_dir, model_name, dataset, index_set=None, api_nproc=4, ignore_failed=False):
    rank, world_size = get_rank_and_world_size()
    # The API path is single-process; parallelism comes from worker threads below.
    assert rank == 0 and world_size == 1
    dataset_name = dataset.dataset_name
    data = dataset.data
    if index_set is not None:
        data = data[data['index'].isin(index_set)]

    model = supported_VLM[model_name]() if isinstance(model, str) else model
    assert getattr(model, 'is_api', False)
    assert hasattr(model, 'chat_inner')

    lt, indices = len(data), list(data['index'])
    structs = [dataset.build_prompt(data.iloc[i]) for i in range(lt)]

    out_file = f'{work_dir}/{model_name}_{dataset_name}_supp.pkl'
    res = {}
    if osp.exists(out_file):
        res = load(out_file)
        if ignore_failed:
            # Each value is a list of per-turn responses; a conversation counts
            # as failed if any turn carries the failure marker.
            res = {k: v for k, v in res.items() if not any(FAIL_MSG in x for x in v)}

    structs = [s for i, s in zip(indices, structs) if i not in res]
    indices = [i for i in indices if i not in res]

    structs = [dict(model=model, messages=struct, dataset_name=dataset_name) for struct in structs]

    if len(structs):
        # track_progress_rich expands each struct to chat_mt(**struct) across
        # api_nproc workers, checkpointing results into out_file keyed by indices.
        track_progress_rich(chat_mt, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)

    res = load(out_file)
    if index_set is not None:
        res = {k: v for k, v in res.items() if k in index_set}
    os.remove(out_file)
    return res
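
# Shape of the _supp.pkl checkpoint maintained above (illustrative values):
#   {0: ['turn-1 answer', 'turn-2 answer'],
#    3: ['Failed to obtain answer via API. <traceback>', '...'], ...}
# On re-entry, indices already present are skipped; with ignore_failed=True,
# conversations containing FAIL_MSG are dropped from the cache and re-run.
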
def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
    dataset_name = dataset.dataset_name
    res = {}
    if osp.exists(out_file):
        res.update(load(out_file))

    # Stripe the dataset across ranks: rank r handles rows r, r + world_size, ...
    rank, world_size = get_rank_and_world_size()
    sheet_indices = list(range(rank, len(dataset), world_size))
    lt = len(sheet_indices)
    data = dataset.data.iloc[sheet_indices]
    data_indices = [i for i in data['index']]

    # If every sample is finished, exit early without building the model.
    all_finished = True
    for i in range(lt):
        idx = data.iloc[i]['index']
        if idx not in res:
            all_finished = False
    if all_finished:
        res = {k: res[k] for k in data_indices}
        dump(res, out_file)
        return

    # Keep only the samples that still need inference.
    data = data[~data['index'].isin(res)]
    lt = len(data)

    model = supported_VLM[model_name]() if isinstance(model, str) else model
    assert hasattr(model, 'chat_inner')

    is_api = getattr(model, 'is_api', False)
    if is_api:
        # Delegate API models to the parallel API path above.
        lt, indices = len(data), list(data['index'])
        supp = infer_data_api(
            model=model,
            work_dir=work_dir,
            model_name=model_name,
            dataset=dataset,
            index_set=set(indices),
            api_nproc=api_nproc)
        for idx in indices:
            assert idx in supp
        res.update(supp)
        res = {k: res[k] for k in data_indices}
        dump(res, out_file)
        return model
    else:
        model.set_dump_image(dataset.dump_image)

    for i in tqdm(range(lt)):
        idx = data.iloc[i]['index']
        if idx in res:
            continue

        if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
            struct = model.build_prompt(data.iloc[i], dataset=dataset_name)
        else:
            struct = dataset.build_prompt(data.iloc[i])

        response = chat_mt(model, struct, dataset_name)
        torch.cuda.empty_cache()

        if verbose:
            print(response, flush=True)

        res[idx] = response
        # Checkpoint every 20 samples so a crash loses little work.
        if (i + 1) % 20 == 0:
            dump(res, out_file)

    res = {k: res[k] for k in data_indices}
    dump(res, out_file)
    return model
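
# Illustration of the striping above: with world_size = 4 and len(dataset) = 10,
# rank 1 processes sheet_indices [1, 5, 9] and dumps only those predictions to
# its own shard file; infer_data_job_mt below merges the shards on rank 0.
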
# A wrapper for infer_data that does the pre- and post-processing.
def infer_data_job_mt(model, work_dir, model_name, dataset, verbose=False, api_nproc=4, ignore_failed=False):
    rank, world_size = get_rank_and_world_size()
    dataset_name = dataset.dataset_name
    result_file = osp.join(work_dir, f'{model_name}_{dataset_name}.tsv')

    # Per-rank shard files: e.g. rank 0 of 4 writes '04_<dataset>.pkl'.
    tmpl = osp.join(work_dir, '{}' + f'{world_size}_{dataset_name}.pkl')
    out_file = tmpl.format(rank)

    model = infer_data(
        model=model, model_name=model_name, work_dir=work_dir, dataset=dataset,
        out_file=out_file, verbose=verbose, api_nproc=api_nproc)
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        # Merge the per-rank shards and check that every index was answered.
        data_all = {}
        for i in range(world_size):
            data_all.update(load(tmpl.format(i)))

        data = dataset.data
        for x in data['index']:
            assert x in data_all

        data['prediction'] = [data_all[x] for x in data['index']]
        if 'image' in data:
            data.pop('image')

        dump(data, result_file)
        for i in range(world_size):
            os.remove(tmpl.format(i))
    return model
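
# A minimal sketch of a driver wiring parse_args to infer_data_job_mt. It is an
# illustration under assumptions, not this repo's official runner:
# `build_dataset` is assumed to be VLMEvalKit's dataset factory, the output
# layout is invented, and multi-GPU runs would additionally need
# torch.distributed initialization before the rank striping takes effect.
if __name__ == '__main__':
    from vlmeval.dataset import build_dataset  # assumed factory import

    args = parse_args()
    for model_name in args.model:
        for dataset_name in args.data:
            dataset = build_dataset(dataset_name)
            work_dir = osp.join('outputs', model_name)  # illustrative layout
            os.makedirs(work_dir, exist_ok=True)
            # Passing the name lets infer_data build the model lazily (and
            # skip building it entirely if all predictions are cached).
            infer_data_job_mt(
                model_name, work_dir=work_dir, model_name=model_name,
                dataset=dataset, verbose=args.verbose, api_nproc=args.nproc)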