Update to MiniCPM-V 2.6

This commit is contained in:
yiranyyu
2024-08-06 12:26:49 +08:00
parent 1cb882d473
commit b1a15299e6
28 changed files with 3692 additions and 191 deletions

76
chat.py
View File

@@ -183,13 +183,87 @@ class MiniCPMV2_5:
)
return answer
class MiniCPMV2_6:
    """Chat wrapper around a MiniCPM-V 2.6 checkpoint.

    Loads the model and tokenizer once in ``__init__`` and exposes ``chat``,
    which turns a JSON-encoded conversation (optionally containing
    base64-encoded images) into a call to the underlying ``model.chat``.
    """

    def __init__(self, model_path, multi_gpus=False) -> None:
        """Load the model and tokenizer from ``model_path``.

        Args:
            model_path: Checkpoint location handed to ``from_pretrained``.
            multi_gpus: When true, spread the model over devices 0 and 1
                using accelerate with a hand-adjusted device map; otherwise
                load the whole model and move it to CUDA.
        """
        print('torch_version:', torch.__version__)
        if multi_gpus:  # inference on multi-gpus
            from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map

            # Build the module structure without allocating real weights so
            # a device map can be inferred before loading the checkpoint.
            with init_empty_weights():
                model = AutoModel.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    attn_implementation='sdpa',
                    torch_dtype=torch.bfloat16,
                )
            device_map = infer_auto_device_map(
                model,
                max_memory={0: "10GB", 1: "10GB"},
                no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'],
            )

            device_id = device_map["llm.model.embed_tokens"]
            device_map["llm.lm_head"] = device_id  # first and last layer of llm should be in the same device
            device_map["vpm"] = device_id
            device_map["resampler"] = device_id

            # Move layers 8..16 onto the device that holds layer 26.
            # NOTE(review): presumably this rebalances memory after pinning
            # vpm/resampler/lm_head to the first device -- confirm.
            device_id2 = device_map["llm.model.layers.26"]
            for layer_idx in range(8, 17):
                device_map[f"llm.model.layers.{layer_idx}"] = device_id2
            print(device_map)

            self.model = load_checkpoint_and_dispatch(
                model, model_path, dtype=torch.bfloat16, device_map=device_map
            )
            self.model.eval()
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                trust_remote_code=True,
                attn_implementation='sdpa',
                torch_dtype=torch.bfloat16,
            )
            self.model.eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    @staticmethod
    def _decode_image(payload):
        """Decode a base64-encoded image payload into an RGB PIL image."""
        return Image.open(io.BytesIO(base64.b64decode(payload))).convert('RGB')

    @classmethod
    def _normalize_content(cls, item):
        """Convert one content item into what ``model.chat`` receives.

        Non-dict items are passed through unchanged. Dict items must carry a
        ``'type'`` of ``'text'`` or ``'image'`` and their data under
        ``'pairs'``; image data is base64-decoded into an RGB image.

        Raises:
            ValueError: If a dict item has any other ``'type'``.
        """
        if not isinstance(item, dict):
            return item
        kind = item['type']
        if kind == 'text':
            return item['pairs']
        if kind == 'image':
            return cls._decode_image(item["pairs"])
        raise ValueError("content type only support text and image.")

    def chat(self, input):
        """Run one chat request against the loaded model.

        Args:
            input: Mapping with a ``"question"`` entry holding a JSON-encoded
                list of messages. Each message's ``'content'`` is a string or
                a list of strings / typed dicts. An optional ``"image"``
                entry (legacy API) holds a base64-encoded image; it is
                ignored when missing, ``None``, or 10 characters or shorter.

        Returns:
            The value returned by ``self.model.chat``, or the string
            ``"Image decode error"`` if the legacy image cannot be decoded.

        Raises:
            ValueError: If a content dict has a type other than text/image.
        """
        image = None
        legacy_image = input.get("image")
        # Guard on truthiness first so a present-but-None field does not
        # crash in len().
        if legacy_image and len(legacy_image) > 10:  # legacy API
            try:
                image = self._decode_image(legacy_image)
            except Exception:
                # Deliberate best-effort: report the failure to the caller
                # as a plain string instead of raising.
                return "Image decode error"

        msgs = json.loads(input["question"])
        for msg in msgs:
            contents = msg.pop('content')  # support str or List[Dict]
            if isinstance(contents, str):
                contents = [contents]
            msg['content'] = [self._normalize_content(c) for c in contents]

        print(f'msgs: {str(msgs)}')

        answer = self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
        )
        return answer
class MiniCPMVChat:
def __init__(self, model_path, multi_gpus=False) -> None:
    """Select the model wrapper that matches ``model_path``.

    The wrapper is chosen by substring match on the path, checked in this
    order: ``'12B'``, ``'MiniCPM-Llama3-V'``, ``'MiniCPM-V-2_6'``; any other
    path falls back to ``MiniCPMV``.

    Args:
        model_path: Checkpoint path or identifier; also forwarded to the
            selected wrapper.
        multi_gpus: Forwarded only to ``MiniCPMV2_6``; the other wrappers
            are constructed from ``model_path`` alone.
    """
    if '12B' in model_path:
        self.model = OmniLMM12B(model_path)
    elif 'MiniCPM-Llama3-V' in model_path:
        self.model = MiniCPMV2_5(model_path)
    elif 'MiniCPM-V-2_6' in model_path:
        self.model = MiniCPMV2_6(model_path, multi_gpus)
    else:
        self.model = MiniCPMV(model_path)