Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
juncaipeng committed Sep 20, 2024
1 parent 438258f commit 0ddfe08
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
85 changes: 85 additions & 0 deletions llm/server/tests/test_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
测试的公共脚本。
"""

import json
import os
import queue
import sys
import uuid
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import *


class OutputData:
    """Collects responses (or errors) returned by the Triton server."""

    def __init__(self):
        # Thread-safe FIFO; the gRPC stream callback pushes each result or
        # InferenceServerException here for the main thread to consume.
        self._completed_requests = queue.Queue()


def triton_callback(output_data, result, error):
    """Triton client stream callback: enqueue the error if any, else the result."""
    item = error if error else result
    output_data._completed_requests.put(item)

def test_base(grpc_url, input_data, test_iters=1, log_level="simple"):
    """Send *input_data* to a Triton gRPC server over a stream and print the replies.

    Args:
        grpc_url: "host:port" of the Triton gRPC endpoint.
        input_data: request payload (dict); JSON-serialized for each request.
        test_iters: number of requests to send over the same stream.
        log_level: "simple" prints only the generated token per chunk,
            "verbose" prints the full result dict.

    Raises:
        ValueError: if log_level is not "simple" or "verbose".
    """
    # Validate arguments.
    if log_level not in ["simple", "verbose"]:
        raise ValueError("log_level must be simple or verbose")

    # Prepare the request descriptors.
    model_name = "model"
    inputs = [grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))]
    outputs = [grpcclient.InferRequestedOutput("OUT")]
    output_data = OutputData()

    # Send requests over a single gRPC stream and consume the responses.
    with grpcclient.InferenceServerClient(url=grpc_url, verbose=False) as triton_client:
        triton_client.start_stream(callback=partial(triton_callback, output_data))
        for i in range(test_iters):
            # BUGFIX: serialize into a fresh local instead of rebinding the
            # `input_data` parameter — the original double-encoded the payload
            # from the second iteration onwards.
            payload = json.dumps([input_data])
            inputs[0].set_data_from_numpy(np.array([payload], dtype=np.object_))

            # Send the request.
            triton_client.async_stream_infer(model_name=model_name,
                                             inputs=inputs,
                                             request_id="{}".format(i),
                                             outputs=outputs)
            # Drain responses for this request until the end marker or an error.
            print("output_data:")
            while True:
                output_item = output_data._completed_requests.get(timeout=10)
                if isinstance(output_item, InferenceServerException):
                    print(f"Exception: status is {output_item.status()}, msg is {output_item.message()}")
                    break
                result = json.loads(output_item.as_numpy("OUT")[0])
                # The server may wrap the result in a single-element list.
                result = result[0] if isinstance(result, list) else result
                if result.get("is_end") == 1 or result.get("error_msg"):
                    # Final chunk (or server-side error): print it and stop.
                    print(f"\n {result} \n")
                    break
                if log_level == "simple":
                    # assumes each chunk carries 'token' or 'token_ids' — TODO confirm
                    print(result['token'] if 'token' in result else result['token_ids'][0], end="")
                else:
                    print(result)

if __name__ == "__main__":
input_data = {
"req_id": 0,
"text": "hello",
"seq_len": 1024,
"min_dec_len": 2,
"penalty_score": 1.0,
"temperature": 0.8,
"topp": 0.8,
"frequency_score": 0.1,
"presence_score": 0.0
}
grpc_url = "0.0.0.0:8891"
test_base(grpc_url=grpc_url, input_data=input_data)
76 changes: 76 additions & 0 deletions llm/server/tests/test_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import argparse
import json
import uuid
from datetime import datetime

import httpx
import requests


def http_no_stream(url, data):
    """POST *data* to *url* and print the complete (non-streaming) response.

    Args:
        url: chat-completions endpoint of the HTTP server.
        data: JSON-serializable request payload.
    """
    print("--非流式接口--")
    headers = {'Content-Type': 'application/json'}
    # timeout matches http_stream; without it a dead server hangs this forever
    resp = requests.post(url, headers=headers, json=data, timeout=300)
    print(resp.text)

def http_stream(url, data, show_chunk=False):
    """POST *data* with streaming enabled and print the assembled result.

    Args:
        url: chat-completions endpoint of the HTTP server.
        data: JSON-serializable payload; a copy is sent with "stream": True
            so the caller's dict is not mutated.
        show_chunk: when True, additionally print every raw chunk as it arrives.
    """
    print("--流式接口--")
    headers = {'Content-Type': 'application/json'}
    data = data.copy()  # avoid mutating the caller's payload
    data["stream"] = True
    with requests.post(url, json=data, headers=headers, timeout=300, stream=True) as r:
        result = ""
        for chunk in r.iter_lines():
            if not chunk:
                continue
            resp = json.loads(chunk)
            # NOTE(review): assumes every chunk carries "error_msg" and
            # "error_code" — a KeyError here means the server schema changed.
            if resp["error_msg"] != "" or resp["error_code"] != 0:
                # Server-side error: dump it and give up on this request.
                print(resp)
                return
            result += resp.get("result", "")
            if show_chunk:
                print(resp)
        print(f"Result: {result}")

def parse_args():
    """Parse and return the command-line options for the HTTP test client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--http_host", default="10.95.147.146", type=str,
                        help="host to the http server")
    parser.add_argument("--http_port", default=8894, type=int,
                        help="port to the http server")
    parser.add_argument("-o", "--open_source_model", action="store_true",
                        help="test eb_model or open_source_model")
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    url = f"http://{args.http_host}:{args.http_port}/v1/chat/completions"
    print(f"url: {url}")

    # Full request with explicit sampling parameters.
    print("\n\n=====单轮对话测试,返回正确结果=====")
    full_request = {
        "req_id": str(uuid.uuid4()),
        "text": "hello",
        "max_dec_len": 1024,
        "min_dec_len": 2,
        "penalty_score": 1.0,
        "temperature": 0.8,
        "topp": 0,
        "frequency_score": 0.1,
        "presence_score": 0.0,
        "timeout": 600,
        "benchmark": True,
    }
    http_no_stream(url, full_request)
    http_stream(url, full_request)

    # Minimal request relying on server-side defaults.
    print("\n\n=====单轮对话测试缺省参数,返回正确结果=====")
    minimal_request = {"text": "hello"}
    http_no_stream(url, minimal_request)
    http_stream(url, minimal_request)

0 comments on commit 0ddfe08

Please sign in to comment.