From eb53ccbc1967d7409976ab6ee7db676c69d265ee Mon Sep 17 00:00:00 2001 From: iflamed Date: Wed, 10 Jul 2024 23:11:45 +0800 Subject: [PATCH] add fastapi client --- README.md | 7 ++- runtime/python/fastapi_client.py | 78 ++++++++++++++++++++++++++++++++ runtime/python/fastapi_server.py | 7 +++ 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 runtime/python/fastapi_client.py diff --git a/README.md b/README.md index 23e4838..5cffa5d 100644 --- a/README.md +++ b/README.md @@ -121,10 +121,13 @@ You can get familiar with CosyVoice following this recipie. The `main.py` file has added a `TTS` api with `CosyVoice-300M-SFT` model, you can update the code based on **Basic Usage** as above. ```sh +cd runtime/python +# Set inference model +export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct # For development -fastapi dev --port 3003 +fastapi dev --port 6006 fastapi_server.py # For production -fastapi run --port 3003 +fastapi run --port 6006 fastapi_server.py ``` **Build for deployment** diff --git a/runtime/python/fastapi_client.py b/runtime/python/fastapi_client.py new file mode 100644 index 0000000..f4b3f12 --- /dev/null +++ b/runtime/python/fastapi_client.py @@ -0,0 +1,78 @@ +import argparse +import logging +import requests + +def saveResponse(path, response): + # 以二进制写入模式打开文件 + with open(path, 'wb') as file: + # 将响应的二进制内容写入文件 + file.write(response.content) + +def main(): + api = args.api_base + if args.mode == 'sft': + url = api + "/api/inference/sft" + payload={ + 'tts': args.tts_text, + 'role': args.spk_id + } + response = requests.request("POST", url, data=payload) + saveResponse(args.tts_wav, response) + elif args.mode == 'zero_shot': + url = api + "/api/inference/zero-shot" + payload={ + 'tts': args.tts_text, + 'prompt': args.prompt_text + } + files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))] + response = requests.request("POST", url, data=payload, files=files) + saveResponse(args.tts_wav, response) + elif args.mode == 'cross_lingual': + url = api + "/api/inference/cross-lingual" + payload={ + 'tts': args.tts_text, + } + files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))] + response = requests.request("POST", url, data=payload, files=files) + saveResponse(args.tts_wav, response) + else: + url = api + "/api/inference/instruct" + payload = { + 'tts': args.tts_text, + 'role': args.spk_id, + 'instruct': args.instruct_text + } + response = requests.request("POST", url, data=payload) + saveResponse(args.tts_wav, response) + logging.info("Response save to {}", args.tts_wav) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--api_base', + type=str, + default='http://127.0.0.1:6006') + parser.add_argument('--mode', + default='sft', + choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], + help='request mode') + parser.add_argument('--tts_text', + type=str, + default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') + parser.add_argument('--spk_id', + type=str, + default='中文女') + parser.add_argument('--prompt_text', + type=str, + default='希望你以后能够做的比我还好呦。') + parser.add_argument('--prompt_wav', + type=str, + default='../../zero_shot_prompt.wav') + parser.add_argument('--instruct_text', + type=str, + default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.') + parser.add_argument('--tts_wav', + type=str, + default='demo.wav') + args = parser.parse_args() + prompt_sr, target_sr = 16000, 22050 + main() diff --git a/runtime/python/fastapi_server.py b/runtime/python/fastapi_server.py index f718373..2dbc619 100644 --- a/runtime/python/fastapi_server.py +++ b/runtime/python/fastapi_server.py @@ -1,3 +1,10 @@ +# Set inference model +# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct +# For development +# fastapi dev --port 6006 fastapi_server.py +# For production deployment +# fastapi run --port 6006 fastapi_server.py + import os import sys import io,time