mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-05 02:09:20 +08:00
update name
This commit is contained in:
8
chat.py
8
chat.py
@@ -136,7 +136,7 @@ def img2base64(file_name):
|
||||
encoded_string = base64.b64encode(f.read())
|
||||
return encoded_string
|
||||
|
||||
class OmniLMM3B:
|
||||
class MiniCPMV:
|
||||
def __init__(self, model_path) -> None:
|
||||
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
@@ -184,14 +184,14 @@ class MiniCPMV2_5:
|
||||
return answer
|
||||
|
||||
|
||||
class OmniLMMChat:
|
||||
class MiniCPMVChat:
|
||||
def __init__(self, model_path) -> None:
|
||||
if '12B' in model_path:
|
||||
self.model = OmniLMM12B(model_path)
|
||||
elif 'MiniCPM-Llama3-V' in model_path:
|
||||
self.model = MiniCPMV2_5(model_path)
|
||||
else:
|
||||
self.model = OmniLMM3B(model_path)
|
||||
self.model = MiniCPMV(model_path)
|
||||
|
||||
def chat(self, input):
|
||||
return self.model.chat(input)
|
||||
@@ -200,7 +200,7 @@ class OmniLMMChat:
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_path = 'openbmb/OmniLMM-12B'
|
||||
chat_model = OmniLMMChat(model_path)
|
||||
chat_model = MiniCPMVChat(model_path)
|
||||
|
||||
im_64 = img2base64('./assets/worldmap_ck.jpg')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user