mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 18:29:18 +08:00
Modify eval_mm for MiniCPM-o 2.6
This commit is contained in:
@@ -5,13 +5,13 @@ import csv
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
import copy as cp
|
||||
import random as rd
|
||||
import requests
|
||||
import shutil
|
||||
import subprocess
|
||||
import warnings
|
||||
import logging
|
||||
import pandas as pd
|
||||
from collections import OrderedDict, defaultdict
|
||||
from multiprocessing import Pool, current_process
|
||||
@@ -21,8 +21,14 @@ import matplotlib.pyplot as plt
|
||||
from tabulate import tabulate
|
||||
from json import JSONDecoder
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from huggingface_hub.utils._cache_manager import _scan_cached_repo
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
|
||||
def modelscope_flag_set():
    """Report whether ModelScope usage is switched on through the environment.

    Returns:
        bool: True when the ``VLMEVALKIT_USE_MODELSCOPE`` environment variable
        is exactly ``'1'`` or ``'True'``; False for any other value or when
        the variable is unset.
    """
    enabled_values = ['1', 'True']
    flag = os.environ.get('VLMEVALKIT_USE_MODELSCOPE', None)
    return flag in enabled_values
|
||||
|
||||
|
||||
def process_punctuation(inText):
|
||||
import re
|
||||
outText = inText
|
||||
@@ -71,26 +77,30 @@ def bincount(lst):
|
||||
bins[item] += 1
|
||||
return bins
|
||||
|
||||
def get_cache_path(repo_id, branch='main', repo_type='datasets'):
    """Locate the local cache directory for a hub repository.

    When ModelScope is enabled (see ``modelscope_flag_set``) the root location
    of the ModelScope cache for the repo is returned. Otherwise the Hugging
    Face cache folder for the repo is scanned and the snapshot path of the
    most recently modified matching revision is returned.

    Args:
        repo_id: Repository identifier of the form ``'org/name'``.
        branch: Only revisions whose refs contain this name are considered.
            Pass ``None`` to consider every cached revision.
        repo_type: Repository type used to build the cache folder name
            (``'datasets'`` is mapped to ``'dataset'`` for ModelScope).

    Returns:
        The cache path, or ``None`` if anything goes wrong (the error is
        logged as a warning rather than raised).
    """
    try:
        if modelscope_flag_set():
            from modelscope.hub.file_download import create_temporary_directory_and_cache
            if repo_type == 'datasets':
                repo_type = 'dataset'
            _, cache = create_temporary_directory_and_cache(model_id=repo_id, repo_type=repo_type)
            cache_path = cache.get_root_location()
            return cache_path
        else:
            from .file import HFCacheRoot
            cache_path = HFCacheRoot()
            org, repo_name = repo_id.split('/')
            repo_path = Path(osp.join(cache_path, f'{repo_type}--{org}--{repo_name}/'))
            hf_cache_info = _scan_cached_repo(repo_path=repo_path)
            # Keep revisions in a list: keying a dict by `refs` would let
            # revisions that share the same ref set (e.g. several detached
            # revisions with no refs) overwrite one another.
            revs = list(hf_cache_info.revisions)
            if branch is not None:
                revs = [r for r in revs if branch in r.refs]
            # max() raises ValueError on an empty list; the handler below
            # turns that into a logged warning and a None return.
            rev2keep = max(revs, key=lambda r: r.last_modified)
            return str(rev2keep.snapshot_path)
    except Exception as e:
        import logging
        logging.warning(f'{type(e)}: {e}')
        return None
|
||||
|
||||
def proxy_set(s):
|
||||
import os
|
||||
@@ -126,14 +136,47 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def timestr(granularity='second'):
    """Return the current local time as a compact digit string.

    The full stamp is formatted as ``%Y%m%d%H%M%S`` and then truncated from
    the right according to ``granularity``.

    Args:
        granularity: One of ``'second'``, ``'minute'``, ``'hour'`` or
            ``'day'``; selects how much of the stamp is kept.

    Returns:
        str: The (possibly truncated) timestamp.

    Raises:
        AssertionError: If ``granularity`` is not one of the accepted values
            (KeyError instead when assertions are disabled with ``-O``).
    """
    assert granularity in ['second', 'minute', 'hour', 'day']
    s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    # Number of trailing characters to drop for each granularity.
    trailing = {'second': 0, 'minute': 2, 'hour': 4, 'day': 6}
    return s[:len(s) - trailing[granularity]]
|
||||
|
||||
def _minimal_ext_cmd(cmd, cwd=None):
|
||||
env = {}
|
||||
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||
v = os.environ.get(k)
|
||||
if v is not None:
|
||||
env[k] = v
|
||||
env['LANGUAGE'] = 'C'
|
||||
env['LANG'] = 'C'
|
||||
env['LC_ALL'] = 'C'
|
||||
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=cwd).communicate()[0]
|
||||
return out
|
||||
|
||||
def githash(fallback='unknown', digits=8):
    """Return the git commit hash of the installed ``vlmeval`` package.

    Args:
        fallback: Value returned when the hash cannot be determined.
        digits: Number of leading hash characters to keep; ``None`` keeps the
            full hash.

    Returns:
        str: The (possibly shortened) hash of ``HEAD``, or ``fallback`` when
        ``vlmeval`` cannot be imported, the git command cannot be launched, or
        git produces no output.

    Raises:
        TypeError: If ``digits`` is neither ``None`` nor an integer.
    """
    if digits is not None and not isinstance(digits, int):
        raise TypeError('digits must be None or an integer')
    try:
        import vlmeval
    except ImportError as e:
        import logging
        logging.error(f'ImportError: {str(e)}')
        return fallback
    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'], cwd=vlmeval.__path__[0])
        sha = out.strip().decode('ascii')
    except OSError:
        return fallback
    # git writes its errors to stderr, so empty stdout means the command
    # failed (e.g. not a git checkout); report the fallback instead of ''.
    if not sha:
        return fallback
    if digits is not None:
        sha = sha[:digits]
    return sha
|
||||
|
||||
def dict_merge(dct, merge_dct):
|
||||
for k, _ in merge_dct.items():
|
||||
@@ -152,17 +195,21 @@ def run_command(cmd):
|
||||
return subprocess.check_output(cmd).decode()
|
||||
|
||||
def load_env():
|
||||
logger = logging.getLogger('LOAD_ENV')
|
||||
import logging
|
||||
logging.basicConfig(
|
||||
format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S')
|
||||
|
||||
try:
|
||||
import vlmeval
|
||||
except ImportError:
|
||||
logger.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
|
||||
logging.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
|
||||
return
|
||||
pth = osp.realpath(vlmeval.__path__[0])
|
||||
pth = osp.join(pth, '../.env')
|
||||
pth = osp.realpath(pth)
|
||||
if not osp.exists(pth):
|
||||
logger.error(f'Did not detect the .env file at {pth}, failed to load. ')
|
||||
logging.error(f'Did not detect the .env file at {pth}, failed to load. ')
|
||||
return
|
||||
|
||||
from dotenv import dotenv_values
|
||||
@@ -170,7 +217,7 @@ def load_env():
|
||||
for k, v in values.items():
|
||||
if v is not None and len(v):
|
||||
os.environ[k] = v
|
||||
logger.info(f'API Keys successfully loaded from {pth}')
|
||||
logging.info(f'API Keys successfully loaded from {pth}')
|
||||
|
||||
def pip_install_robust(package):
|
||||
import sys
|
||||
@@ -214,3 +261,31 @@ def extract_json_objects(text, decoder=JSONDecoder()):
|
||||
pos = match + index
|
||||
except ValueError:
|
||||
pos = match + 1
|
||||
|
||||
|
||||
def get_gpu_memory():
    """Query ``nvidia-smi`` for the free memory reported for each GPU.

    Returns:
        list[int]: The leading integer of each data row printed by
        ``nvidia-smi --query-gpu=memory.free --format=csv``, in output order.
        An empty list is returned (and the error printed) if the command
        fails or its output cannot be parsed.
    """
    import subprocess
    query = ['nvidia-smi', '--query-gpu=memory.free', '--format=csv']
    try:
        report = subprocess.check_output(query).decode('ascii')
        # Drop the CSV header (first row) and the empty piece after the
        # final newline (last row).
        rows = report.split('\n')[1:-1]
        return [int(row.split()[0]) for row in rows]
    except Exception as e:
        print(f'{type(e)}: {str(e)}')
        return []
|
||||
|
||||
|
||||
def auto_split_flag():
    """Decide whether automatic splitting across GPUs should be enabled.

    The ``AUTO_SPLIT`` environment variable set to ``'1'`` forces the answer
    to True. Otherwise the number of CUDA devices visible to torch is compared
    with the world size reported by ``get_rank_and_world_size``.

    Returns:
        bool: True if ``AUTO_SPLIT`` is ``'1'``, or if there are more CUDA
        devices than processes and the device count is an exact multiple of
        the world size. False otherwise, including when torch cannot be
        imported or the device query fails.
    """
    if os.environ.get('AUTO_SPLIT', '0') == '1':
        return True
    _, world_size = get_rank_and_world_size()
    try:
        import torch
        device_count = torch.cuda.device_count()
        return device_count > world_size and device_count % world_size == 0
    except Exception:
        # Best-effort probe: any failure (missing torch, zero world size, ...)
        # means "do not auto split". Unlike a bare `except:`, this lets
        # KeyboardInterrupt and SystemExit propagate.
        return False
|
||||
|
||||
Reference in New Issue
Block a user