From 27e4acc741b77f5b0147bbffba267e2a8a76d0e3 Mon Sep 17 00:00:00 2001 From: Hongji Zhu Date: Wed, 17 Apr 2024 21:19:31 +0800 Subject: [PATCH] Add web demo code --- web_demo.py | 260 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 web_demo.py diff --git a/web_demo.py b/web_demo.py new file mode 100644 index 0000000..ad9c54a --- /dev/null +++ b/web_demo.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python +# encoding: utf-8 +import gradio as gr +from PIL import Image +import traceback +import re +import torch +import argparse +from transformers import AutoModel, AutoTokenizer + +# README, How to run demo on different devices +# For Nvidia GPUs support BF16 (like A100, H100, RTX3090) +# python web_demo.py --device cuda --dtype bf16 + +# For Nvidia GPUs do NOT support BF16 (like V100, T4, RTX2080) +# python web_demo.py --device cuda --dtype fp16 + +# For Mac with MPS (Apple silicon or AMD GPUs). +# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps --dtype fp16 + +# Argparser +parser = argparse.ArgumentParser(description='demo') +parser.add_argument('--device', type=str, default='cuda', help='cuda or mps') +parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16') +args = parser.parse_args() +device = args.device +assert device in ['cuda', 'mps'] +if args.dtype == 'bf16': + dtype = torch.bfloat16 +else: + dtype = torch.float16 + +# Load model +model_path = 'openbmb/MiniCPM-V-2' +model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +model = model.to(device=device, dtype=dtype) +model.eval() + + + +ERROR_MSG = "Error, please retry" +model_name = 'MiniCPM-V 2.0' + +form_radio = { + 'choices': ['Beam Search', 'Sampling'], + #'value': 'Beam Search', + 'value': 'Sampling', + 'interactive': True, + 'label': 'Decode Type' +} +# Beam Form +num_beams_slider = { + 'minimum': 0, + 'maximum': 5, + 'value': 3, + 'step': 1, + 'interactive': True, + 'label': 'Num Beams' +} +repetition_penalty_slider = { + 'minimum': 0, + 'maximum': 3, + 'value': 1.2, + 'step': 0.01, + 'interactive': True, + 'label': 'Repetition Penalty' +} +repetition_penalty_slider2 = { + 'minimum': 0, + 'maximum': 3, + 'value': 1.05, + 'step': 0.01, + 'interactive': True, + 'label': 'Repetition Penalty' +} +max_new_tokens_slider = { + 'minimum': 1, + 'maximum': 4096, + 'value': 1024, + 'step': 1, + 'interactive': True, + 'label': 'Max New Tokens' +} + +top_p_slider = { + 'minimum': 0, + 'maximum': 1, + 'value': 0.8, + 'step': 0.05, + 'interactive': True, + 'label': 'Top P' +} +top_k_slider = { + 'minimum': 0, + 'maximum': 200, + 'value': 100, + 'step': 1, + 'interactive': True, + 'label': 'Top K' +} +temperature_slider = { + 'minimum': 0, + 'maximum': 2, + 'value': 0.7, + 'step': 0.05, + 'interactive': True, + 'label': 'Temperature' +} + + +def create_component(params, comp='Slider'): + if comp == 'Slider': + return gr.Slider( + minimum=params['minimum'], + maximum=params['maximum'], + value=params['value'], + step=params['step'], + interactive=params['interactive'], + label=params['label'] + ) + elif comp == 'Radio': + return gr.Radio( + choices=params['choices'], + value=params['value'], + interactive=params['interactive'], + label=params['label'] + ) + elif comp == 'Button': + return gr.Button( + value=params['value'], + interactive=True + ) + + +def chat(img, msgs, ctx, params=None, vision_hidden_states=None): + default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024} + if params is None: + params = default_params + if img is None: + return -1, "Error, invalid image, please upload a new image", None, None + try: + image = img.convert('RGB') + answer, context, _ = model.chat( + image=image, + msgs=msgs, + context=None, + tokenizer=tokenizer, + **params + ) + res = re.sub(r'(.*)', '', answer) + res = res.replace('', '') + res = res.replace('', '') + res = res.replace('', '') + answer = res.replace('', '') + return -1, answer, None, None + except Exception as err: + print(err) + traceback.print_exc() + return -1, ERROR_MSG, None, None + + +def upload_img(image, _chatbot, _app_session): + image = Image.fromarray(image) + + _app_session['sts']=None + _app_session['ctx']=[] + _app_session['img']=image + _chatbot.append(('', 'Image uploaded successfully, you can talk to me now')) + return _chatbot, _app_session + + +def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature): + if _app_cfg.get('ctx', None) is None: + _chat_bot.append((_question, 'Please upload an image to start')) + return '', _chat_bot, _app_cfg + + _context = _app_cfg['ctx'].copy() + if _context: + _context.append({"role": "user", "content": _question}) + else: + _context = [{"role": "user", "content": _question}] + print(':', _question) + + if params_form == 'Beam Search': + params = { + 'sampling': False, + 'num_beams': num_beams, + 'repetition_penalty': repetition_penalty, + "max_new_tokens": 896 + } + else: + params = { + 'sampling': True, + 'top_p': top_p, + 'top_k': top_k, + 'temperature': temperature, + 'repetition_penalty': repetition_penalty_2, + "max_new_tokens": 896 + } + code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params) + print(':', _answer) + + _context.append({"role": "assistant", "content": _answer}) + _chat_bot.append((_question, _answer)) + if code == 0: + _app_cfg['ctx']=_context + _app_cfg['sts']=sts + return '', _chat_bot, _app_cfg + + +def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature): + if len(_chat_bot) <= 1: + _chat_bot.append(('Regenerate', 'No question for regeneration.')) + return '', _chat_bot, _app_cfg + elif _chat_bot[-1][0] == 'Regenerate': + return '', _chat_bot, _app_cfg + else: + _question = _chat_bot[-1][0] + _chat_bot = _chat_bot[:-1] + _app_cfg['ctx'] = _app_cfg['ctx'][:-2] + return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature) + + + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(scale=1, min_width=300): + params_form = create_component(form_radio, comp='Radio') + with gr.Accordion("Beam Search") as beams_according: + num_beams = create_component(num_beams_slider) + repetition_penalty = create_component(repetition_penalty_slider) + with gr.Accordion("Sampling") as sampling_according: + top_p = create_component(top_p_slider) + top_k = create_component(top_k_slider) + temperature = create_component(temperature_slider) + repetition_penalty_2 = create_component(repetition_penalty_slider2) + regenerate = create_component({'value': 'Regenerate'}, comp='Button') + with gr.Column(scale=3, min_width=500): + app_session = gr.State({'sts':None,'ctx':None,'img':None}) + bt_pic = gr.Image(label="Upload an image to start") + chat_bot = gr.Chatbot(label=f"Chat with {model_name}") + txt_message = gr.Textbox(label="Input text") + + regenerate.click( + regenerate_button_clicked, + [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature], + [txt_message, chat_bot, app_session] + ) + txt_message.submit( + respond, + [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature], + [txt_message, chat_bot, app_session] + ) + bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session]) + +# launch +demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0") +