# MiniCPM-o/web_demos/minicpm-o_2.6/model_server.py
import argparse
import asyncio
import base64
import io
import json
import logging
import os
import sys
import threading
import time
import wave
from typing import Dict, List, Any, Optional

import aiofiles
import librosa
import numpy as np
import soundfile
import torch
import uvicorn
from fastapi import FastAPI, Header, Query, Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoProcessor

cur_path = os.path.split(os.path.realpath(__file__))[0]
sys.path.append(os.path.abspath(cur_path))

import vad_utils

def setup_logger():
    logger = logging.getLogger("api_logger")
    logger.setLevel(logging.DEBUG)

    # Shared formatter for both handlers
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d-%(levelname)s-[%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # INFO goes to stdout (the handler level drops DEBUG);
    # the filter keeps WARNING and above off stdout
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(formatter)
    stdout_handler.addFilter(lambda record: record.levelno <= logging.INFO)

    # WARNING, ERROR, CRITICAL go to stderr
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(formatter)

    logger.addHandler(stdout_handler)
    logger.addHandler(stderr_handler)
    return logger

app = FastAPI()
logger = setup_logger()

ap = argparse.ArgumentParser()
ap.add_argument('--port', type=int, default=32550)
ap.add_argument('--model', type=str, default="openbmb/MiniCPM-o-2_6", help="huggingface model name or local path")
args = ap.parse_args()
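
# Example launch (hypothetical invocation; the defaults match the argparse setup above):
#   python model_server.py --port 32550 --model openbmb/MiniCPM-o-2_6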


class StreamManager:
    def __init__(self):
        self.uid = None
        self.is_streaming_complete = threading.Event()
        self.conversation_started = threading.Event()
        self.last_request_time = None
        self.last_stream_time = None
        self.timeout = 900  # overall inactivity timeout in seconds
        self.stream_timeout = 3  # seconds without incoming stream data
        self.num_stream = 0
        self.stream_started = False
        self.stop_response = False

        # VAD settings
        self.vad_options = vad_utils.VadOptions()
        self.vad_sequence_length = 5
        self.vad_sequence = []
        self.audio_prefill = []
        self.audio_input = []
        self.image_prefill = None
        self.audio_chunk = 200  # expected chunk duration in ms (prefill flushes every 1000/audio_chunk chunks)

        # customized options
        self.customized_audio = None
        self.customized_options = None

        # Omni model
        self.target_dtype = torch.bfloat16
        self.device = 'cuda:0'
        self.minicpmo_model_path = args.model  # e.g. "openbmb/MiniCPM-o-2_6"
        self.model_version = "2.6"
        with torch.no_grad():
            self.minicpmo_model = AutoModel.from_pretrained(
                self.minicpmo_model_path, trust_remote_code=True,
                torch_dtype=self.target_dtype, attn_implementation='sdpa')
            self.minicpmo_tokenizer = AutoTokenizer.from_pretrained(
                self.minicpmo_model_path, trust_remote_code=True)
            self.minicpmo_model.init_tts()
            # self.minicpmo_model.tts.float()
            self.minicpmo_model.to(self.device).eval()

        self.ref_path_video_default = "assets/ref_audios/video_default.wav"
        self.ref_path_default = "assets/ref_audios/default.wav"
        self.ref_path_female = "assets/ref_audios/female_example.wav"
        self.ref_path_male = "assets/ref_audios/male_example.wav"

        self.input_audio_id = 0
        self.input_audio_vad_id = 0
        self.input_image_id = 0
        self.output_audio_id = 0
        self.flag_decode = False
        self.cnts = None
        self.all_start_time = time.time()
        self.session_id = 233
        self.sys_prompt_flag = False
        self.vad_time = 0
        self.ls_time = 0
        self.msg_type = 1
        self.speaking_time_stamp = 0
        self.cycle_wait_time = 12800 / 24000 + 0.15  # one 12800-sample chunk at 24 kHz, plus margin
        self.extra_wait_time = 2.5
        self.server_wait = True
        self.past_session_id = 0

        self.sys_prompt_init(0)
        self.session_id += 1

    def start_conversation(self):
        logger.info(f"uid {self.uid}: new conversation started.")
        self.conversation_started.set()
        self.stop_response = False

    def update_last_request_time(self):
        self.last_request_time = time.time()
        # logger.info(f"update last_request_time {self.last_request_time}")

    def update_last_stream_time(self):
        self.last_stream_time = time.time()
        # logger.info(f"update last_stream_time {self.last_stream_time}")

    def move_to_device(self, obj, device):
        if isinstance(obj, torch.Tensor):
            obj_ = obj.to(device)
            if (obj_.dtype == torch.float) or (obj_.dtype == torch.half):
                # cast to `torch.bfloat16`
                obj_ = obj_.to(self.target_dtype)
            return obj_
        elif isinstance(obj, dict):
            return {key: self.move_to_device(value, device) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [self.move_to_device(item, device) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(self.move_to_device(item, device) for item in obj)
        elif isinstance(obj, set):
            return {self.move_to_device(item, device) for item in obj}
        else:
            return obj

    def reset(self):
        logger.info("reset")
        self.is_streaming_complete.clear()
        self.conversation_started.clear()
        self.last_request_time = None
        self.last_stream_time = None
        self.audio_buffer_raw = bytearray()
        self.num_stream = 0
        self.stream_started = False
        self.stop_response = False
        # self.customized_audio = None
        # self.customized_options = None
        # clear model
        self.clear()

    def merge_wav_files(self, input_bytes_list, output_file):
        with wave.open(io.BytesIO(input_bytes_list[0]), 'rb') as wav:
            params = wav.getparams()
            n_channels, sampwidth, framerate, n_frames, comptype, compname = params
        with wave.open(output_file, 'wb') as output_wav:
            output_wav.setnchannels(n_channels)
            output_wav.setsampwidth(sampwidth)
            output_wav.setframerate(framerate)
            output_wav.setcomptype(comptype, compname)
            for wav_bytes in input_bytes_list:
                with wave.open(io.BytesIO(wav_bytes), 'rb') as wav:
                    output_wav.writeframes(wav.readframes(wav.getnframes()))

    def is_timed_out(self):
        if self.last_request_time is not None:
            return time.time() - self.last_request_time > self.timeout
        return False

    def no_active_stream(self):
        if self.last_stream_time is not None and self.stream_started:
            no_stream_duration = time.time() - self.last_stream_time
            if no_stream_duration > self.stream_timeout:
                # logger.info(f"no active stream for {no_stream_duration} secs.")
                return True
        return False

    def sys_prompt_init(self, msg_type):
        if self.past_session_id == self.session_id:
            return
        logger.info("### sys_prompt_init ###")
        logger.info(f'msg_type is {msg_type}')

        if msg_type <= 1:  # audio
            audio_voice_clone_prompt = "Use the voice in the audio prompt to synthesize new content."
            audio_assistant_prompt = "You are a helpful assistant with the above voice style."
            ref_path = self.ref_path_default
            if self.customized_options is not None:
                audio_voice_clone_prompt = self.customized_options['voice_clone_prompt']
                audio_assistant_prompt = self.customized_options['assistant_prompt']
                if self.customized_options['use_audio_prompt'] == 1:
                    ref_path = self.ref_path_default
                elif self.customized_options['use_audio_prompt'] == 2:
                    ref_path = self.ref_path_female
                elif self.customized_options['use_audio_prompt'] == 3:
                    ref_path = self.ref_path_male
            audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
            sys_msg = {'role': 'user', 'content': [audio_voice_clone_prompt + "\n", audio_prompt, "\n" + audio_assistant_prompt]}
        elif msg_type == 2:  # video
            # "You are an AI assistant. You can accept video, audio and text input and output
            #  speech and text. Imitate the voice characteristics in the input audio."
            voice_clone_prompt = "你是一个AI助手。你能接受视频音频和文本输入并输出语音和文本。模仿输入音频中的声音特征。"
            # "As the assistant, you will speak in this voice style."
            assistant_prompt = "作为助手,你将使用这种声音风格说话。"
            ref_path = self.ref_path_video_default
            if self.customized_options is not None:
                voice_clone_prompt = self.customized_options['voice_clone_prompt']
                assistant_prompt = self.customized_options['assistant_prompt']
                if self.customized_options['use_audio_prompt'] == 1:
                    ref_path = self.ref_path_default
                elif self.customized_options['use_audio_prompt'] == 2:
                    ref_path = self.ref_path_female
                elif self.customized_options['use_audio_prompt'] == 3:
                    ref_path = self.ref_path_male
            audio_prompt, sr = librosa.load(ref_path, sr=16000, mono=True)
            sys_msg = {'role': 'user', 'content': [voice_clone_prompt, audio_prompt, assistant_prompt]}
        # elif msg_type == 3:  # user start
        #     # "As the assistant, you will speak in this voice style."
        #     assistant_prompt = "作为助手,你将使用这种声音风格说话。"
        #     if self.customized_options is not None:
        #         assistant_prompt = self.customized_options['assistant_prompt']
        #     sys_msg = {'role': 'user', 'content': [assistant_prompt]}

        self.msg_type = msg_type
        msgs = [sys_msg]
        if self.customized_options is not None:
            if self.customized_options['use_audio_prompt'] > 0:
                self.minicpmo_model.streaming_prefill(
                    session_id=str(self.session_id),
                    msgs=msgs,
                    tokenizer=self.minicpmo_tokenizer,
                )
        if msg_type == 0:
            self.minicpmo_model.streaming_prefill(
                session_id=str(self.session_id),
                msgs=msgs,
                tokenizer=self.minicpmo_tokenizer,
            )

        self.savedir = os.path.join(f"./log_data/{args.port}/", str(time.time()))
        os.makedirs(self.savedir, exist_ok=True)
        for sub in ("input_audio_log", "input_audio_vad_log", "input_image_log",
                    "output_audio_log", "feedback_log", "input_audio"):
            os.makedirs(os.path.join(self.savedir, sub), exist_ok=True)

        self.past_session_id = self.session_id
        self.audio_prefill = []
        self.audio_input = []

    def clear(self):
        try:
            self.flag_decode = False
            self.stream_started = False
            self.cnts = None
            self.vad_sequence = []
            self.audio_prefill = []
            self.audio_input = []
            self.image_prefill = None
            if self.minicpmo_model.llm_past_key_values[0][0].shape[2] > 8192:
                self.session_id += 1  # to clear all kv cache
            self.sys_prompt_flag = False
            self.vad_time = 0
            self.ls_time = 0
            self.msg_type = 1
        except Exception as e:
            raise ValueError(f"Clear error: {str(e)}")

    def process_message(self, message: Dict[str, Any]):
        try:
            # Process content items
            audio_data = None
            image_data = None
            for content_item in message["content"]:
                if content_item["type"] == "stop_response":
                    logger.info("process_message: received request to stop_response")
                    self.stop_response = True
                    return "stop"
                elif content_item["type"] == "input_audio":
                    audio_data = content_item["input_audio"]["data"]
                    audio_timestamp = content_item["input_audio"].get("timestamp", "")
                elif content_item["type"] == "image_data":
                    image_data = content_item["image_data"]["data"]

            if audio_data is None:
                return "empty audio"
            if self.conversation_started.is_set() and self.is_streaming_complete.is_set():
                # current turn already complete, generation pending; drop this chunk
                logger.info("turn already complete, generation pending; skip stream message.")
                return "skip"
            if self.flag_decode:
                return "skip"

            try:
                audio_bytes = base64.b64decode(audio_data)
                image = None
                if image_data is not None:
                    if len(image_data) > 0:
                        image_bytes = base64.b64decode(image_data)
                        image_buffer = io.BytesIO(image_bytes)
                        image_buffer.seek(0)
                        image = Image.open(image_buffer)
                        # logger.info("read image")
                if not self.sys_prompt_flag:
                    self.all_start_time = time.time()
                    self.sys_prompt_flag = True
                    if image_data is not None:
                        self.sys_prompt_init(2)
                    else:
                        self.sys_prompt_init(1)

                self.prefill(audio_bytes, image, False)

                self.vad_sequence.append(audio_bytes)
                if len(self.vad_sequence) < self.vad_sequence_length:
                    # not enough chunks buffered for VAD yet
                    return "done"
                elif len(self.vad_sequence) > self.vad_sequence_length:
                    self.vad_sequence.pop(0)
                self.vad_check_audio_bytes(audio_bytes, image, 16000)
                return "done"
            except Exception as e:
                raise ValueError(f"Audio processing error: {str(e)}")
        except Exception as e:
            raise ValueError(f"Message processing error: {str(e)}")

    def resample_audio(self, input_path, src_sr, tar_sr, output_path):
        audio_data, _ = librosa.load(input_path, sr=src_sr)
        audio_new = librosa.resample(audio_data, orig_sr=src_sr, target_sr=tar_sr)
        soundfile.write(output_path, audio_new, tar_sr)

    def calculate_rms(self, input_path, sr):
        # simple energy gate: True if the clip is louder than a fixed RMS floor
        audio_data, _ = librosa.load(input_path, sr=sr)
        return (np.sqrt(np.mean(audio_data**2)) > 0.002)

    def vad_check_audio_bytes(self, audio, image, sr):
        try:
            input_audio_vad_path = self.savedir + f"/input_audio_vad_log/vad_{self.input_audio_vad_id}.wav"
            self.input_audio_vad_id += 1
            self.merge_wav_files(self.vad_sequence, input_audio_vad_path)
            with open(input_audio_vad_path, "rb") as f:
                temp_audio = f.read()
            dur_vad, vad_audio_bytes, time_vad = vad_utils.run_vad(temp_audio, sr, self.vad_options)

            if self.customized_options is not None:
                vad_threshold = 1 - self.customized_options['vad_threshold']
            else:
                vad_threshold = 0.2

            if self.calculate_rms(input_audio_vad_path, sr) and dur_vad > 0.4:
                # speech detected: mark the start of the user's turn
                if not self.stream_started:
                    self.vad_time = time.time()
                    self.stream_started = True
            elif dur_vad < vad_threshold:
                # silence detected: if the user spoke long enough, end the turn
                if self.stream_started:
                    self.stream_started = False
                    if (time.time() - self.vad_time >= 0.6):
                        self.prefill(audio, image, True)
                        self.is_streaming_complete.set()
                        # self.ls_time = time.time()
        except Exception as e:
            logger.error(f"VAD error: {e}")
            raise
        return

    def prefill(self, audio, image, is_end):
        if self.server_wait:
            # hold off while the previous response is still expected to be playing
            now = time.time()
            await_time = self.speaking_time_stamp - now + self.extra_wait_time
            if await_time > 0:
                return False
        if self.flag_decode:
            return False
        if image is not None:
            self.image_prefill = image
        try:
            if not is_end:
                self.audio_prefill.append(audio)
                self.audio_input.append(audio)

            slice_nums = 1
            if is_end and self.customized_options is not None:
                if self.customized_options['hd_video']:
                    slice_nums = 6
                else:
                    return True

            # flush one second of buffered chunks, or whatever remains at end of turn
            if (len(self.audio_prefill) == (1000 / self.audio_chunk)) or (is_end and len(self.audio_prefill) > 0):
                time_prefill = time.time()
                input_audio_path = self.savedir + f"/input_audio_log/input_audio_{self.input_audio_id}.wav"
                self.merge_wav_files(self.audio_prefill, input_audio_path)
                with open(input_audio_path, "rb") as wav_io:
                    signal, sr = soundfile.read(wav_io, dtype='float32')
                soundfile.write(input_audio_path, signal, 16000)
                audio_np, sr = librosa.load(input_audio_path, sr=16000, mono=True)
                self.audio_prefill = []
                if len(audio_np) > 16000:
                    audio_np = audio_np[:16000]  # cap at one second of 16 kHz audio

                with torch.no_grad():
                    if self.image_prefill is not None:
                        input_image_path = self.savedir + f'/input_image_log/input_image_{self.input_audio_id}.png'
                        self.image_prefill.save(input_image_path, 'PNG')
                        self.image_prefill = self.image_prefill.convert("RGB")
                    cnts = None
                    if self.image_prefill is not None:
                        cnts = ["<unit>", self.image_prefill, audio_np]
                    else:
                        cnts = [audio_np]
                    if cnts is not None:
                        msg = {"role": "user", "content": cnts}
                        msgs = [msg]
                        res = self.minicpmo_model.streaming_prefill(
                            session_id=str(self.session_id),
                            msgs=msgs,
                            tokenizer=self.minicpmo_tokenizer,
                            max_slice_nums=slice_nums,
                        )
                self.input_audio_id += 1
            return True
        except Exception as e:
            logger.error(f"prefill error: {e}")
            import traceback
            traceback.print_exc()
            raise

    def generate_end(self):
        self.input_audio_id += 10
        self.output_audio_id += 10
        self.flag_decode = False
        self.reset()
        return

    async def generate(self):
        """Yield (base64-encoded wav, response text) pairs for the current turn."""
        if self.stop_response:
            self.generate_end()
            return
        self.flag_decode = True
        try:
            with torch.no_grad():
                logger.info("=== model gen start ===")
                time_gen = time.time()

                # log the merged user-turn audio, then echo it back as the first event
                input_audio_path = self.savedir + f"/input_audio/all_input_audio_{self.input_audio_id}.wav"
                self.merge_wav_files(self.audio_input, input_audio_path)
                audio_stream = None
                try:
                    with open(input_audio_path, 'rb') as wav_file:
                        audio_stream = wav_file.read()
                except FileNotFoundError:
                    print(f"File {input_audio_path} not found.")
                if audio_stream is not None:  # guard so b64encode never sees None
                    yield base64.b64encode(audio_stream).decode('utf-8'), "assistant:\n"
                print('=== gen start: ', time.time() - time_gen)

                first_time = True
                temp_time = time.time()
                temp_time1 = time.time()
                with torch.inference_mode():
                    if self.stop_response:
                        self.generate_end()
                        return
                    self.minicpmo_model.config.stream_input = True
                    msg = {"role": "user", "content": self.cnts}
                    msgs = [msg]
                    text = ''
                    self.speaking_time_stamp = time.time()
                    try:
                        for r in self.minicpmo_model.streaming_generate(
                            session_id=str(self.session_id),
                            tokenizer=self.minicpmo_tokenizer,
                            generate_audio=True,
                            # enable_regenerate=True,
                        ):
                            if self.stop_response:
                                self.generate_end()
                                return
                            audio_np, sr, text = r["audio_wav"], r["sampling_rate"], r["text"]
                            output_audio_path = self.savedir + f'/output_audio_log/output_audio_{self.output_audio_id}.wav'
                            self.output_audio_id += 1
                            soundfile.write(output_audio_path, audio_np, samplerate=sr)
                            audio_stream = None
                            try:
                                with open(output_audio_path, 'rb') as wav_file:
                                    audio_stream = wav_file.read()
                            except FileNotFoundError:
                                print(f"File {output_audio_path} not found.")
                            if audio_stream is not None:
                                temp_time1 = time.time()
                                print('text: ', text)
                                yield base64.b64encode(audio_stream).decode('utf-8'), text
                            # pace the stream roughly to real-time playback
                            self.speaking_time_stamp += self.cycle_wait_time
                    except Exception as e:
                        logger.error(f"Error happened during generation: {str(e)}")
                        yield None, '\n<end>'
        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            import traceback
            traceback.print_exc()
            raise
        finally:
            logger.info(f"uid {self.uid}: generation finished!")
            self.generate_end()

    async def check_activity(self):
        while True:
            # Check for overall inactivity (self.timeout, 900 s by default)
            if self.is_timed_out():
                self.reset()
            if self.no_active_stream() and not self.is_streaming_complete.is_set():
                self.is_streaming_complete.set()
            await asyncio.sleep(1)  # check every second

    def upload_customized_audio(self, audio_data, audio_fmt):
        self.customized_audio = None
        try:
            if audio_data is not None and len(audio_data) > 0:
                # if audio_fmt == "mp3" or audio_fmt == "wav":
                audio_bytes = base64.b64decode(audio_data)
                fio = io.BytesIO(audio_bytes)
                fio.seek(0)
                audio_np, sr = librosa.load(fio, sr=16000, mono=True)
                if audio_np is not None and len(audio_np) > 1000:
                    output_audio_path = self.savedir + '/customized_audio.wav'
                    soundfile.write(output_audio_path, audio_np, sr)
                    self.customized_audio = output_audio_path
                    logger.info(f"processed customized {audio_fmt} audio")
                    print(audio_np.shape, type(audio_np), sr)
                else:
                    logger.info("empty customized audio, using default value instead.")
                    self.customized_audio = None
        except Exception as e:
            raise ValueError(f"Process customized audio error: {str(e)}")

    def update_customized_options(self, uid, options):
        self.customized_options = None
        if options is None:
            raise ValueError("Invalid None type for options, expected dict type")
        self.customized_options = options
        logger.info(f"uid: {uid} set customized_options to {options}")


stream_manager = StreamManager()


@app.on_event("startup")
async def startup_event():
    logger.info("Starting application and activity checker")
    asyncio.create_task(stream_manager.check_activity())


@app.on_event("shutdown")
async def shutdown_event():
    logger.info("Shutting down application")
@app.post("/stream")
@app.post("/api/v1/stream")
async def stream(request: Request, uid: Optional[str] = Header(None)):
global stream_manager
stream_manager.update_last_request_time()
stream_manager.update_last_stream_time()
if not uid:
raise HTTPException(status_code=400, detail="Missing uid in headers")
if stream_manager.uid is not None and stream_manager.uid != uid:
logger.error(f"uid changed during steram: previous uid {stream_manager.uid}, new uid {uid}")
raise HTTPException(status_code=400, detail="uid changed in stream")
try:
# Parse JSON request
data = await request.json()
# Validate basic structure
if not isinstance(data, dict) or "messages" not in data:
raise HTTPException(status_code=400, detail="Invalid request format")
# Process messages
reason = ""
for message in data["messages"]:
if not isinstance(message, dict) or "role" not in message or "content" not in message:
raise HTTPException(status_code=400, detail="Invalid message format")
reason = stream_manager.process_message(message)
# Return response using uid from header
response = {
"id": uid,
"choices": {
"role": "assistant",
"content": "success",
"finish_reason": reason
}
}
return JSONResponse(content=response, status_code=200)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
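
# A sketch of the /stream payload shape this handler expects, inferred from
# process_message (the field values here are placeholders, not from the repo):
#
#   {
#     "messages": [{
#       "role": "user",
#       "content": [
#         {"type": "input_audio",
#          "input_audio": {"data": "<base64 wav chunk>", "timestamp": "..."}},
#         {"type": "image_data", "image_data": {"data": "<base64 image>"}}
#       ]
#     }]
#   }
#
# A {"type": "stop_response"} content item interrupts the current generation.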
@app.websocket("/ws/stream")
@app.websocket("/ws/api/v1/stream")
async def websocket_stream(websocket: WebSocket,
uid: Optional[str] = Query(None)):
global stream_manager
if not uid:
await websocket.close(code=400, reason="Missing uid in request")
return
# Accept the WebSocket connection
await websocket.accept()
#if stream_manager.uid is not None and stream_manager.uid != uid:
# logger.error(f"uid changed during steram: previous uid {stream_manager.uid}, new uid {uid}")
# await websocket.close(code=400, reason="Uid changed in stream.")
# return
try:
while True:
# Continuously listen for incoming messages from the client
data = await websocket.receive_text()
# Parse JSON request
try:
request_data = json.loads(data)
except json.JSONDecodeError:
await websocket.send_json({"error": "Invalid JSON"})
continue
stream_manager.update_last_request_time()
stream_manager.update_last_stream_time()
if stream_manager.uid is not None and stream_manager.uid != uid:
logger.error(f"uid changed during stream: previous uid {stream_manager.uid}, new uid {uid}")
await websocket.send_json({"error": "UID changed in stream"})
continue
# Validate basic structure
if not isinstance(request_data, dict) or "messages" not in request_data:
await websocket.send_json({"error": "Invalid request format"})
continue
# Process messages
try:
reason = ""
for message in request_data["messages"]:
if not isinstance(message, dict) or "role" not in message or "content" not in message:
await websocket.send_json({"error": "Invalid message format"})
continue
reason = stream_manager.process_message(message)
# Respond with success message
response = {
"id": uid,
"choices": {
"role": "assistant",
"content": "success",
"finish_reason": reason,
},
}
await websocket.send_json(response)
except WebSocketDisconnect:
# Handle WebSocket disconnection
break
except Exception as e:
logger.error(f"process message error: {str(e)}")
await websocket.close(code=1011, reason =f"Internal server error: {str(e)}")
except WebSocketDisconnect:
# Handle WebSocket disconnection
return
except Exception as e:
logger.error(f"ws_stream error: {str(e)}")
await websocket.close(code=1011, reason =f"Unexpected error: {str(e)}")


async def generate_sse_response(request: Request, uid: Optional[str] = Header(None)):
    global stream_manager
    print(f"uid: {uid}")
    try:
        # Wait for streaming to complete or timeout
        while not stream_manager.is_streaming_complete.is_set():
            # if stream_manager.is_timed_out():
            #     yield f"data: {json.dumps({'error': 'Stream timeout'})}\n\n"
            #     return
            await asyncio.sleep(0.1)
        logger.info("streaming complete\n")

        # Generate response
        try:
            yield "event: message\n"
            async for audio, text in stream_manager.generate():
                if text == "stop":
                    break
                res = {
                    "id": stream_manager.uid,
                    "response_id": stream_manager.output_audio_id,
                    "choices": [
                        {
                            "role": "assistant",
                            "audio": audio,
                            "text": text,
                            "finish_reason": "processing"
                        }
                    ]
                }
                # logger.info("generate_sse_response yield response")
                yield f"data: {json.dumps(res)}\n\n"
                await asyncio.sleep(0)
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            yield f'data:{{"error": "{str(e)}"}}\n\n'
    except Exception as e:
        yield f'data:{{"error": "{str(e)}"}}\n\n'
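
# Each SSE event emitted above looks roughly like this (shape taken from the
# `res` dict; the base64 payload is truncated here for illustration):
#
#   event: message
#   data: {"id": "<uid>", "response_id": 3, "choices": [{"role": "assistant",
#          "audio": "<base64 wav>", "text": "...", "finish_reason": "processing"}]}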
@app.post("/completions")
@app.post("/api/v1/completions")
async def completions(request: Request, uid: Optional[str] = Header(None)):
global stream_manager
if not uid:
raise HTTPException(status_code=400, detail="Missing uid in headers")
try:
# if stream_manager.uid is not None and stream_manager.uid != uid:
if stream_manager.uid != uid:
# stream_manager.stop_response = True
# logger.info(f"uid changed, reset model: previous uid {stream_manager.uid}, new uid {uid}")
stream_manager.session_id += 1
stream_manager.sys_prompt_flag = False
stream_manager.reset()
# raise HTTPException(
# status_code=409,
# detail="User id changed, reset context."
# )
stream_manager.speaking_time_stamp = 0
stream_manager.update_last_request_time()
stream_manager.uid = uid
stream_manager.start_conversation()
data = await request.json()
return StreamingResponse(
generate_sse_response(request, uid),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked"
}
)
except asyncio.TimeoutError:
raise HTTPException(
status_code=503,
detail="Server busy, please try again later"
)
except Exception as e:
logger.error(f"Error processing request for user {uid}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/stop")
@app.post("/api/v1/stop")
async def stop_response(request: Request, uid: Optional[str] = Header(None)):
if not uid:
raise HTTPException(status_code=400, detail="Missing uid in headers")
global stream_manager
# stream_manager.session_id += 1
logger.info(f"uid {uid}: received stop_response")
stream_manager.stop_response = True
response = {
"id": uid,
"choices": {
"role": "assistant",
"content": "success",
"finish_reason": "stop"
}
}
return JSONResponse(content=response, status_code=200)
@app.post("/feedback")
@app.post("/api/v1/feedback")
async def feedback(request: Request, uid: Optional[str] = Header(None)):
global stream_manager
# Validate the 'uid' header
if not uid:
raise HTTPException(status_code=400, detail="Missing 'uid' header")
try:
data = await request.json()
if "response_id" not in data or "rating" not in data:
raise HTTPException(status_code=400, detail="Invalid request: must have response_id and rating")
response_id = data.get("response_id", "")
rating = data.get("rating", "")
comment = data.get("comment", "")
# Validate the rating
if rating not in ["like", "dislike"]:
raise HTTPException(status_code=400, detail=f"Invalid rating value: {rating}")
# Define the log file path
log_file_path = f"{stream_manager.savedir}/feedback_log/{response_id}.{rating}"
# Write the feedback to the file asynchronously
async with aiofiles.open(log_file_path, mode="a") as file:
await file.write(f"model: {stream_manager.minicpmo_model_path}\nuid {uid}: {comment}\n")
response = {
"id": uid,
"choices": {
"role": "assistant",
"content": "success",
"finish_reason": "done"
}
}
return JSONResponse(content=response, status_code=200)
except Exception as e:
logger.error(f"Error processing feedback for user {uid}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
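
# A sketch of the /feedback payload, inferred from the validation above
# (values are placeholders):
#
#   {"response_id": 42, "rating": "like", "comment": "optional free text"}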
@app.post("/init_options")
@app.post("/api/v1/init_options")
async def init_options(request: Request, uid: Optional[str] = Header(None)):
global stream_manager
stream_manager.update_last_request_time()
if not uid:
raise HTTPException(status_code=400, detail="Missing uid in headers")
try:
# Parse JSON request
data = await request.json()
# Validate basic structure
if not isinstance(data, dict) or "messages" not in data:
raise HTTPException(status_code=400, detail="Invalid request format")
messages = data.get("messages", [])
for message in messages:
if not isinstance(message, dict) or "role" not in message or "content" not in message:
raise HTTPException(status_code=400, detail="Invalid message format")
for content in message.get("content", []):
if content["type"] == "input_audio":
audio_data = content["input_audio"].get("data", "")
audio_fmt = content["input_audio"].get("format", "")
stream_manager.upload_customized_audio(audio_data, audio_fmt)
elif content["type"] == "options":
stream_manager.update_customized_options(uid, content["options"])
else:
ctype = content["type"]
raise HTTPException(status_code=400, detail=f"Invalid content type: {ctype}")
version = stream_manager.model_version
print(version)
response = {
"id": uid,
"choices": {
"role": "assistant",
"content": version,
"finish_reason": "done"
}
}
return JSONResponse(content=response, status_code=200)
except Exception as e:
raise HTTPException(status_code=400, detail=f"init options error: {str(e)}")
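
# A sketch of an "options" content item, using only the keys this server reads
# (in sys_prompt_init, vad_check_audio_bytes and prefill); values are illustrative:
#
#   {"type": "options", "options": {
#       "voice_clone_prompt": "...",  # system prompt for voice cloning
#       "assistant_prompt": "...",    # system prompt for the assistant persona
#       "use_audio_prompt": 1,        # 0 none, 1 default, 2 female, 3 male reference voice
#       "vad_threshold": 0.8,         # applied as 1 - vad_threshold for the silence cutoff
#       "hd_video": False             # True -> prefill the final frame with max_slice_nums=6
#   }}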


@app.get('/health')
@app.get('/api/v1/health')
async def health_check():
    return {"status": "OK"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=args.port)
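
# Typical client flow against this server, sketched from the endpoints above
# (the ordering is an assumption; chunk size follows audio_chunk = 200 ms):
#   1. POST /api/v1/init_options  with a "uid" header to set voice/VAD options (optional)
#   2. POST /api/v1/stream        repeatedly with ~200 ms base64 wav chunks (16 kHz mono)
#   3. POST /api/v1/completions   once per conversation; read the SSE stream for
#      base64 audio + text events until generation finishes
#   4. POST /api/v1/stop          to interrupt; /api/v1/feedback to rate a response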