update name

This commit is contained in:
Hongji Zhu
2024-05-23 19:30:48 +08:00
parent 6dde8b97ea
commit 43e7097649
4 changed files with 10 additions and 10 deletions

View File

@@ -492,13 +492,13 @@ Please refer to the following codes to run.
```python ```python
from chat import OmniLMMChat, img2base64 from chat import MiniCPMVChat, img2base64
import torch import torch
import json import json
torch.manual_seed(0) torch.manual_seed(0)
chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5')
im_64 = img2base64('./assets/airplane.jpeg') im_64 = img2base64('./assets/airplane.jpeg')

View File

@@ -470,13 +470,13 @@ Please refer to the following codes to run `MiniCPM-V` and `OmniLMM`.
```python ```python
from chat import OmniLMMChat, img2base64 from chat import MiniCPMVChat, img2base64
import torch import torch
import json import json
torch.manual_seed(0) torch.manual_seed(0)
chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5')
im_64 = img2base64('./assets/airplane.jpeg') im_64 = img2base64('./assets/airplane.jpeg')

View File

@@ -506,13 +506,13 @@ pip install -r requirements.txt
```python ```python
from chat import OmniLMMChat, img2base64 from chat import MiniCPMVChat, img2base64
import torch import torch
import json import json
torch.manual_seed(0) torch.manual_seed(0)
chat_model = OmniLMMChat('openbmb/MiniCPM-Llama3-V-2_5') chat_model = MiniCPMVChat('openbmb/MiniCPM-Llama3-V-2_5')
im_64 = img2base64('./assets/airplane.jpeg') im_64 = img2base64('./assets/airplane.jpeg')

View File

@@ -136,7 +136,7 @@ def img2base64(file_name):
encoded_string = base64.b64encode(f.read()) encoded_string = base64.b64encode(f.read())
return encoded_string return encoded_string
class OmniLMM3B: class MiniCPMV:
def __init__(self, model_path) -> None: def __init__(self, model_path) -> None:
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16) self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -184,14 +184,14 @@ class MiniCPMV2_5:
return answer return answer
class OmniLMMChat: class MiniCPMVChat:
def __init__(self, model_path) -> None: def __init__(self, model_path) -> None:
if '12B' in model_path: if '12B' in model_path:
self.model = OmniLMM12B(model_path) self.model = OmniLMM12B(model_path)
elif 'MiniCPM-Llama3-V' in model_path: elif 'MiniCPM-Llama3-V' in model_path:
self.model = MiniCPMV2_5(model_path) self.model = MiniCPMV2_5(model_path)
else: else:
self.model = OmniLMM3B(model_path) self.model = MiniCPMV(model_path)
def chat(self, input): def chat(self, input):
return self.model.chat(input) return self.model.chat(input)
@@ -200,7 +200,7 @@ class OmniLMMChat:
if __name__ == '__main__': if __name__ == '__main__':
model_path = 'openbmb/OmniLMM-12B' model_path = 'openbmb/OmniLMM-12B'
chat_model = OmniLMMChat(model_path) chat_model = MiniCPMVChat(model_path)
im_64 = img2base64('./assets/worldmap_ck.jpg') im_64 = img2base64('./assets/worldmap_ck.jpg')