diff --git a/README.md b/README.md index 1276253..2fa06b5 100644 --- a/README.md +++ b/README.md @@ -492,13 +492,13 @@ Please refer to the following codes to run. ```python -from chat import OmniLMMChat, img2base64 +from chat import MiniCPMVChat, img2base64 import torch import json torch.manual_seed(0) -chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') +chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5') im_64 = img2base64('./assets/airplane.jpeg') diff --git a/README_en.md b/README_en.md index c35f271..4c72d2e 100644 --- a/README_en.md +++ b/README_en.md @@ -470,13 +470,13 @@ Please refer to the following codes to run `MiniCPM-V` and `OmniLMM`. ```python -from chat import OmniLMMChat, img2base64 +from chat import MiniCPMVChat, img2base64 import torch import json torch.manual_seed(0) -chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') +chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5') im_64 = img2base64('./assets/airplane.jpeg') diff --git a/README_zh.md b/README_zh.md index bd2eada..3736d18 100644 --- a/README_zh.md +++ b/README_zh.md @@ -506,13 +506,13 @@ pip install -r requirements.txt ```python -from chat import OmniLMMChat, img2base64 +from chat import MiniCPMVChat, img2base64 import torch import json torch.manual_seed(0) -chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') +chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5') im_64 = img2base64('./assets/airplane.jpeg') diff --git a/chat.py b/chat.py index 77ba8f7..8dbf8ef 100644 --- a/chat.py +++ b/chat.py @@ -136,7 +136,7 @@ def img2base64(file_name): encoded_string = base64.b64encode(f.read()) return encoded_string -class OmniLMM3B: +class MiniCPMV: def __init__(self, model_path) -> None: self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16) self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) @@ -184,14 +184,14 @@ class MiniCPMV2_5: return answer -class OmniLMMChat: +class MiniCPMVChat: def __init__(self, model_path) -> None: if '12B' in model_path: self.model = OmniLMM12B(model_path) elif 'MiniCPM-Llama3-V' in model_path: self.model = MiniCPMV2_5(model_path) else: - self.model = OmniLMM3B(model_path) + self.model = MiniCPMV(model_path) def chat(self, input): return self.model.chat(input) @@ -200,7 +200,7 @@ class OmniLMMChat: if __name__ == '__main__': model_path = 'openbmb/OmniLMM-12B' - chat_model = OmniLMMChat(model_path) + chat_model = MiniCPMVChat(model_path) im_64 = img2base64('./assets/worldmap_ck.jpg')