mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
192 lines
6.9 KiB
Python
192 lines
6.9 KiB
Python
from multiprocessing import Pool
|
|
import os
|
|
from typing import Callable, Iterable, Sized
|
|
|
|
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
|
|
TaskProgressColumn, TextColumn, TimeRemainingColumn)
|
|
from rich.text import Text
|
|
import os.path as osp
|
|
import portalocker
|
|
from ..smp import load, dump
|
|
|
|
|
|
class _Worker:
|
|
"""Function wrapper for ``track_progress_rich``"""
|
|
|
|
def __init__(self, func) -> None:
|
|
self.func = func
|
|
|
|
def __call__(self, inputs):
|
|
inputs, idx = inputs
|
|
if not isinstance(inputs, (tuple, list, dict)):
|
|
inputs = (inputs, )
|
|
|
|
if isinstance(inputs, dict):
|
|
return self.func(**inputs), idx
|
|
else:
|
|
return self.func(*inputs), idx
|
|
|
|
|
|
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Time-remaining column that hides the estimate during a warm-up phase.

    While ``task.completed`` is at most ``skip_times`` the column renders a
    fixed ``-:--:--`` placeholder instead of the inherited estimate.

    Args:
        skip_times (int): Completed-count threshold up to which the
            placeholder is shown. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Render the remaining time, or a placeholder while warming up."""
        warming_up = task.completed <= self.skip_times
        if not warming_up:
            return super().render(task)
        return Text('-:--:--', style='progress.remaining')
|
|
|
|
|
|
def _tasks_with_index(tasks):
|
|
"""Add index to tasks."""
|
|
for idx, task in enumerate(tasks):
|
|
yield task, idx
|
|
|
|
|
|
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment variable ``name``.

    Returns ``default`` when the variable is unset. Otherwise ``''``, ``'0'``,
    ``'false'``, ``'no'`` and ``'off'`` (case-insensitive, surrounding
    whitespace ignored) mean False and any other value means True.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ('', '0', 'false', 'no', 'off')


def _persist_result(save: str, key, result, verbose: bool) -> None:
    """Merge ``{key: result}`` into the file at ``save`` under a file lock.

    The whole read-modify-write cycle (``load`` -> update -> ``dump``) runs
    while the portalocker lock is held, so concurrent writers serialize.
    """
    with portalocker.Lock(save, timeout=5) as fh:
        ans = load(save)
        ans[key] = result

        if verbose:
            print(key, result, flush=True)

        dump(ans, save)
        # Force the freshly dumped content to disk before releasing the lock.
        fh.flush()
        os.fsync(fh.fileno())


def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        save=None, keys=None,
                        color: str = 'blue') -> list:
    """Run ``func`` over ``tasks`` while displaying a rich progress bar.

    With ``nproc == 1`` the tasks run sequentially in the calling process.
    Otherwise they are dispatched to a :class:`multiprocessing.Pool` through
    :meth:`Pool.imap_unordered`, and the results are put back into task order
    once all of them have arrived.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): The tasks. Each element supplies the
            arguments of one ``func`` call:

            - a ``dict`` is passed as keyword arguments;
            - a ``tuple`` or ``list`` is unpacked as positional arguments;
            - any other object is passed as the single positional argument.

            When ``func`` accepts no arguments, pass an empty tuple and set
            ``task_num``. Defaults to an empty tuple.
        task_num (int, optional): Number of tasks. Required when ``tasks`` is
            empty and must equal ``len(tasks)`` when ``tasks`` is sized. For
            an unsized iterator it is only used as the progress-bar total.
            Defaults to None.
        nproc (int): Number of worker processes. If ``nproc`` is 1, run in
            the current process. Defaults to 1.
        chunksize (int): Forwarded to :meth:`Pool.imap_unordered`.
            Defaults to 1.
        description (str): The description of the progress bar.
            Defaults to "Processing".
        save (str, optional): Path of a file to which every result is
            written as soon as it is produced, stored under ``keys[idx]``
            using the project's ``load``/``dump`` helpers. The file is
            created holding ``{}`` if missing; existing entries are kept.
            Requires ``keys``. Defaults to None.
        keys (Sized, optional): One key per task, used to index results in
            ``save``. Defaults to None.
        color (str): Passed to ``Progress.add_task`` as ``color``.
            NOTE(review): none of the configured columns appear to read this
            field. Defaults to "blue".

    Environment:
        VERBOSE: When ``save`` is set, controls whether each saved
            ``key result`` pair is printed. If unset, printing is on for
            ``nproc == 1`` and off otherwise. ``0``, ``false``, ``no``,
            ``off`` or an empty value disable it; any other value enables it.

    Examples:
        >>> import time

        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)

    Returns:
        list: The task results, in the same order as ``tasks``.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        ValueError: If ``task_num`` is missing or inconsistent with
            ``tasks``, if ``nproc`` is not positive, or if ``save`` is given
            without ``keys``.
        AssertionError: If the directory of ``save`` does not exist, or if
            ``keys`` does not have one entry per task.
    """
    # Validate all arguments before touching the filesystem.
    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            # Zero-argument calls: one empty argument tuple per task.
            tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    if save is not None and keys is None:
        raise ValueError('keys must be provided when save is set')
    # Compare against the normalized task count: ``tasks`` may be an unsized
    # iterator, or an empty tuple expanded from ``task_num`` above.
    if keys is not None and task_num is not None:
        assert len(keys) == task_num

    if save is not None:
        assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
        if not osp.exists(save):
            dump({}, save)

    # The historical per-path defaults are preserved: print by default when
    # running in-process, stay quiet by default with a pool.
    verbose = _env_flag('VERBOSE', default=(nproc == 1))

    # With a pool the first ``nproc * chunksize`` completions arrive in a
    # burst, so the ETA is hidden until then.
    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    indexed_tasks = _tasks_with_index(tasks)

    # Use single process when nproc is 1, else use multiprocess.
    with prog_bar:
        if nproc == 1:
            results = []
            for task in indexed_tasks:
                # Run func exactly once and reuse its result everywhere.
                result, idx = worker(task)
                results.append(result)
                if save is not None:
                    _persist_result(save, keys[idx], result, verbose)
                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            unordered_results = []
            with Pool(nproc) as pool:
                gen = pool.imap_unordered(worker, indexed_tasks, chunksize)
                try:
                    for result, idx in gen:
                        unordered_results.append((result, idx))
                        if save is not None:
                            _persist_result(save, keys[idx], result, verbose)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception:
                    prog_bar.stop()
                    raise
            # Results arrive in completion order; place each one back at the
            # index it was submitted with.
            results = [None] * len(unordered_results)
            for result, idx in unordered_results:
                results[idx] = result
    return results
|