mirror of https://github.com/OpenBMB/MiniCPM-V.git (synced 2026-02-04 17:59:18 +08:00)
Update to MiniCPM-Llama3-V 2.5
chat.py | 25
@@ -160,11 +160,36 @@ class OmniLMM3B:
         )
         return answer
 
+class MiniCPMV2_5:
+    def __init__(self, model_path) -> None:
+        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.model.eval().cuda()
+
+    def chat(self, input):
+        try:
+            image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
+        except Exception as e:
+            return "Image decode error"
+
+        msgs = json.loads(input['question'])
+
+        answer = self.model.chat(
+            image=image,
+            msgs=msgs,
+            tokenizer=self.tokenizer,
+            sampling=True,
+            temperature=0.7
+        )
+        return answer
+
+
 class OmniLMMChat:
     def __init__(self, model_path) -> None:
         if '12B' in model_path:
             self.model = OmniLMM12B(model_path)
+        elif 'MiniCPM-Llama3-V' in model_path:
+            self.model = MiniCPMV2_5(model_path)
         else:
             self.model = OmniLMM3B(model_path)
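For reference, a minimal driver sketch (not part of the commit) showing how the new branch might be exercised. It assumes only the interface visible in the hunk above: OmniLMMChat routes to MiniCPMV2_5 whenever 'MiniCPM-Llama3-V' appears in the model path, and chat() expects a dict with a base64-encoded 'image' and a JSON-encoded 'question' message list. The model path, image file, and question text below are placeholders.

import base64
import json

# Placeholder model path; it only needs to contain 'MiniCPM-Llama3-V'
# so that OmniLMMChat selects the new MiniCPMV2_5 wrapper.
chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5')

# Encode an example image as base64, matching what MiniCPMV2_5.chat() decodes.
with open('example.jpg', 'rb') as f:
    im_b64 = base64.b64encode(f.read()).decode('utf-8')

# The 'question' field is a JSON-encoded msgs list, per json.loads(input['question']).
msgs = [{'role': 'user', 'content': 'What is in this image?'}]
inputs = {'image': im_b64, 'question': json.dumps(msgs, ensure_ascii=True)}

print(chat_model.chat(inputs))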