mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
184 lines
7.0 KiB
Python
184 lines
7.0 KiB
Python
import torch
|
|
import torch.distributed as dist
|
|
from vlmeval.config import supported_VLM
|
|
from vlmeval.utils import track_progress_rich
|
|
from vlmeval.smp import *
|
|
|
|
FAIL_MSG = 'Failed to obtain answer via API.'
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the video inference script.

    Returns:
        argparse.Namespace: with attributes ``data`` and ``model`` (non-empty
        lists of strings, both mandatory), ``nproc`` (int, 4 when omitted)
        and ``verbose`` (bool, False when omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, nargs='+', required=True)
    parser.add_argument('--model', type=str, nargs='+', required=True)
    # Deliberately not `required`: a required option never falls back to its
    # default, which made `default=4` dead and forced every caller to pass it.
    parser.add_argument('--nproc', type=int, default=4)
    parser.add_argument('--verbose', action='store_true')
    return parser.parse_args()
|
|
|
|
|
|
# Only API models are accepted
|
|
def infer_data_api(model, work_dir, model_name, dataset, samples_dict=None, api_nproc=4):
    """Run an API-backed model over the given samples, resuming from a cache file.

    Samples whose cached answer is missing or equals ``FAIL_MSG`` are
    (re)submitted; everything else is taken from the cache as is.

    Args:
        model: A model instance, or a string, in which case the model is
            created via ``supported_VLM[model_name]()``.
        work_dir: Directory holding the ``*_supp.pkl`` cache file.
        model_name: Key into ``supported_VLM``; also part of the cache file name.
        dataset: Dataset object; this function reads ``dataset_name``,
            ``nframe``, ``fps``, the optional ``pack`` attribute and calls
            ``build_prompt``.
        samples_dict: Mapping from sample index to the sample handed to
            ``dataset.build_prompt``. ``None`` (the default) means no samples.
        api_nproc: Forwarded to ``track_progress_rich`` as both ``nproc`` and
            ``chunksize``.

    Returns:
        The cached results. When something was submitted this is the cache
        file re-read after inference; otherwise it is the cache as loaded at
        the start (an empty dict when no cache file exists).

    Raises:
        AssertionError: if not running as a single process (rank 0, world
            size 1), or if the model is not flagged ``is_api``.
    """
    # `None` stands in for the former `{}` default, which was one dict object
    # shared by every call.
    samples_dict = {} if samples_dict is None else samples_dict

    rank, world_size = get_rank_and_world_size()
    assert rank == 0 and world_size == 1
    dataset_name = dataset.dataset_name
    model = supported_VLM[model_name]() if isinstance(model, str) else model
    assert getattr(model, 'is_api', False)

    packstr = 'pack' if getattr(dataset, 'pack', False) else 'nopack'
    if dataset.nframe > 0:
        out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.nframe}frame_{packstr}_supp.pkl'
    else:
        out_file = f'{work_dir}/{model_name}_{dataset_name}_{dataset.fps}fps_{packstr}_supp.pkl'
    res = load(out_file) if osp.exists(out_file) else {}

    # Work out what is still pending first, so prompts are only built for
    # samples that will actually be sent to the API.
    indices = [i for i in samples_dict if i not in res or res[i] == FAIL_MSG]
    video_llm = getattr(model, 'VIDEO_LLM', False)
    structs = [
        dict(message=dataset.build_prompt(samples_dict[i], video_llm=video_llm), dataset=dataset_name)
        for i in indices
    ]

    if structs:
        # NOTE(review): presumably track_progress_rich writes results into
        # `out_file` under `keys` -- confirm against its implementation.
        track_progress_rich(
            model.generate, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices)
        res = load(out_file)
    # With nothing pending the cache file may not exist at all, so it is only
    # re-read after an inference run.
    return res
|
|
|
|
|
|
def _sync_frame_sampling(model, model_name, dataset):
    """Reconcile the model's ``nframe`` / ``fps`` attributes with the dataset's."""
    if getattr(model, 'nframe', None) is not None and getattr(model, 'nframe', 0) > 0:
        if dataset.nframe > 0:
            if getattr(model, 'nframe', 0) != dataset.nframe:
                print(f'{model_name} is a video-llm model, nframe is set to {dataset.nframe}, not using default')
                setattr(model, 'nframe', dataset.nframe)
        elif getattr(model, 'fps', 0) == 0:
            raise ValueError(f'fps is not suitable for {model_name}')
        else:
            setattr(model, 'nframe', None)
    if getattr(model, 'fps', None) is not None and getattr(model, 'fps', 0) > 0:
        if dataset.fps > 0:
            if getattr(model, 'fps', 0) != dataset.fps:
                print(f'{model_name} is a video-llm model, fps is set to {dataset.fps}, not using default')
                setattr(model, 'fps', dataset.fps)
        elif getattr(model, 'nframe', 0) == 0:
            raise ValueError(f'nframe is not suitable for {model_name}')
        else:
            setattr(model, 'fps', None)


def infer_data(model, model_name, work_dir, dataset, out_file, verbose=False, api_nproc=4):
    """Run inference for this rank's share of the dataset and save it to ``out_file``.

    Sample indices are split across ranks as ``sample_indices[rank::world_size]``.
    Results already present in ``out_file`` are reused, so an interrupted run
    can be resumed.

    Args:
        model: A model instance, or a string, in which case the model is
            created via ``supported_VLM[model_name]()``.
        model_name: Key into ``supported_VLM``; also used in messages.
        work_dir: Forwarded to ``infer_data_api`` for API models.
        dataset: Dataset object; this function reads ``dataset_name``,
            ``data``, ``videos`` (pack mode), ``nframe``, ``fps`` and calls
            ``build_prompt``.
        out_file: Path of this rank's result file.
        verbose: When true, print every response as it is produced.
        api_nproc: Forwarded to ``infer_data_api`` for API models.

    Returns:
        The model object. When every sample of this rank is already cached,
        ``model`` is returned untouched (possibly still a string).

    Raises:
        AssertionError: for an API model with ``world_size != 1`` or with
            missing answers, or for a non-API model on a ``pack`` dataset.
        ValueError: when the model's and the dataset's ``nframe`` / ``fps``
            settings cannot be reconciled, or a custom prompt is requested
            while ``dataset.nframe == 0``.
    """
    res = load(out_file) if osp.exists(out_file) else {}
    rank, world_size = get_rank_and_world_size()
    dataset_name = dataset.dataset_name

    is_pack = getattr(dataset, 'pack', False)
    if is_pack:
        # Pack mode works per video: the video identifier is index and sample.
        sample_indices = list(dataset.videos)
        samples = list(dataset.videos)
    else:
        # Otherwise map each `index` value to its row position in `data`.
        sample_indices = list(dataset.data['index'])
        samples = list(range(len(dataset.data)))
    sample_map = dict(zip(sample_indices, samples))

    sample_indices_sub = sample_indices[rank::world_size]
    if all(idx in res for idx in sample_indices_sub):
        return model
    sample_indices_subrem = [x for x in sample_indices_sub if x not in res]

    model = supported_VLM[model_name]() if isinstance(model, str) else model

    if getattr(model, 'is_api', False):
        assert world_size == 1
        supp = infer_data_api(
            model=model,
            work_dir=work_dir,
            model_name=model_name,
            dataset=dataset,
            samples_dict={k: sample_map[k] for k in sample_indices_subrem},
            api_nproc=api_nproc)
        for k in sample_indices_subrem:
            assert k in supp
        res.update(supp)
        dump(res, out_file)
        return model

    assert not is_pack, 'Current model not supported pack mode!'

    # tqdm wraps the list itself, not the enumerate object, so it can read
    # the length and show a total.
    for i, idx in enumerate(tqdm(sample_indices_subrem)):
        if idx in res:
            # Reached when the same index occurs more than once in the list.
            continue

        # Kept inside the loop to preserve the original per-sample ordering.
        _sync_frame_sampling(model, model_name, dataset)

        row = dataset.data.iloc[sample_map[idx]]
        if 'SUB_DATASET' in row:
            dataset_name = row['SUB_DATASET']

        video_llm = getattr(model, 'VIDEO_LLM', False)
        if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name):
            if dataset.nframe == 0:
                raise ValueError(f'nframe must be set for custom prompt, fps is not suitable for {model_name}')
            struct = model.build_prompt(row, dataset=dataset, video_llm=video_llm)
        else:
            struct = dataset.build_prompt(sample_map[idx], video_llm=video_llm)

        response = model.generate(message=struct, dataset=dataset_name)
        torch.cuda.empty_cache()

        if verbose:
            print(response, flush=True)

        res[idx] = response
        # Checkpoint every 20 samples so an interrupted run can resume.
        if (i + 1) % 20 == 0:
            dump(res, out_file)

    # Keep only this rank's entries in the file that gets written.
    res = {k: res[k] for k in sample_indices_sub}
    dump(res, out_file)
    return model
|
|
|
|
|
|
# A wrapper around infer_data that handles the pre- and post-processing
|
|
def infer_data_job_video(
        model,
        work_dir,
        model_name,
        dataset,
        result_file_name,
        verbose=False,
        api_nproc=4):
    """Run video inference on every rank and merge the shards into one result file.

    Each rank writes its predictions to its own ``.pkl`` shard in ``work_dir``
    via ``infer_data``. Rank 0 then merges all shards, attaches the
    predictions to the dataset's metadata, dumps that to
    ``work_dir/result_file_name`` and deletes the shards.

    Args:
        model: Model instance or string, forwarded to ``infer_data``.
        work_dir: Directory for the shard files and the final result file.
        model_name: Forwarded to ``infer_data``.
        dataset: Dataset object; ``dataset_name``, ``data`` and, in pack mode,
            ``load_pack_answers`` are used here.
        result_file_name: File name of the final result inside ``work_dir``.
        verbose: Forwarded to ``infer_data``.
        api_nproc: Forwarded to ``infer_data``.

    Returns:
        The model as returned by ``infer_data``, or the ``model`` argument
        unchanged when the result file already exists.
    """
    dataset_name = dataset.dataset_name
    rank, world_size = get_rank_and_world_size()
    result_file = osp.join(work_dir, result_file_name)

    # A finished result file means there is nothing left to do.
    if osp.exists(result_file):
        return model

    # Shard path template: the `{}` placeholder is filled with a rank.
    result_stem = osp.splitext(result_file_name)[0]
    tmpl = osp.join(work_dir, '{}' + f'{world_size}_{result_stem}.pkl')

    model = infer_data(
        model=model,
        model_name=model_name,
        work_dir=work_dir,
        dataset=dataset,
        out_file=tmpl.format(rank),
        verbose=verbose,
        api_nproc=api_nproc)

    # Wait until every rank has written its shard before merging.
    if world_size > 1:
        dist.barrier()

    # Only rank 0 merges and writes the final result.
    if rank != 0:
        return model

    shard_files = [tmpl.format(shard_rank) for shard_rank in range(world_size)]
    data_all = {}
    for shard_file in shard_files:
        data_all.update(load(shard_file))

    meta = dataset.data
    if dataset_name == 'MMBench-Video' and getattr(dataset, 'pack', False):
        meta, vstats = dataset.load_pack_answers(data_all)
        print(f'Statitics of Pack Video Inference: {vstats}')
    else:
        # Every dataset index must have a prediction before it is attached.
        for sample_index in meta['index']:
            assert sample_index in data_all
        meta['prediction'] = [str(data_all[sample_index]) for sample_index in meta['index']]
        if 'image' in meta:
            meta.pop('image')

    dump(meta, result_file)
    for shard_file in shard_files:
        os.remove(shard_file)
    return model
|