Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion test/ce/server/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
from .request_template import TEMPLATES
from .utils import (
build_request_payload,
get_logprobs_list,
get_probs_list,
get_stream_chunks,
get_token_list,
send_request,
)

# Public names re-exported by this package. A single assignment is kept: an
# earlier one-line ``__all__`` would be overwritten by this list anyway.
__all__ = [
    "build_request_payload",
    "send_request",
    "TEMPLATES",
    "get_stream_chunks",
    "get_token_list",
    "get_logprobs_list",
    "get_probs_list",
]

# 检查环境变量是否存在
URL = os.environ.get("URL")
Expand Down
39 changes: 39 additions & 0 deletions test/ce/server/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python

import json
import math

import requests
from core import TEMPLATES, base_logger
Expand Down Expand Up @@ -97,3 +98,41 @@ def get_token_list(response):

base_logger.info(f"Token List:{token_list}")
return token_list


def get_logprobs_list(response):
    """Extract the per-token log-probabilities from a completion response.

    Reads ``response["choices"][0]["logprobs"]["content"]`` and collects the
    ``"logprob"`` value of every entry that has one.

    Args:
        response: Decoded JSON body of the completion response.

    Returns:
        The ``"logprob"`` values in response order, skipping entries whose
        ``"logprob"`` is missing or ``None``. An empty list is returned (and
        an error is logged) when the logprobs section is missing, ``None``,
        or not a sequence of dict-like entries.
    """
    try:
        content_logprobs = response["choices"][0]["logprobs"]["content"]
        # Walk the entries inside the ``try``: a ``null`` content field or a
        # non-dict entry is a parse failure too and must not escape as an
        # uncaught TypeError/AttributeError.
        candidates = [token_info.get("logprob") for token_info in content_logprobs]
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        base_logger.error(f"解析失败:{e}")
        return []

    logprobs_list = [value for value in candidates if value is not None]

    base_logger.info(f"Logprobs List:{logprobs_list}")
    return logprobs_list


def get_probs_list(response):
    """Extract the per-token probabilities from a completion response.

    Reads ``response["choices"][0]["logprobs"]["content"]`` and converts the
    ``"logprob"`` value of every entry that has one to a probability with
    ``math.exp``.

    Args:
        response: Decoded JSON body of the completion response.

    Returns:
        ``math.exp(logprob)`` for each entry in response order, skipping
        entries whose ``"logprob"`` is missing or ``None``. An empty list is
        returned (and an error is logged) when the logprobs section is
        missing, ``None``, or not a sequence of dict-like entries.
    """
    try:
        content_logprobs = response["choices"][0]["logprobs"]["content"]
        # Walk the entries inside the ``try``: a ``null`` content field or a
        # non-dict entry is a parse failure too and must not escape as an
        # uncaught TypeError/AttributeError.
        candidates = [token_info.get("logprob") for token_info in content_logprobs]
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        base_logger.error(f"解析失败:{e}")
        return []

    probs_list = [math.exp(value) for value in candidates if value is not None]

    base_logger.info(f"probs List:{probs_list}")
    return probs_list
58 changes: 56 additions & 2 deletions test/ce/server/test_base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@

import json

from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request
from core import (
TEMPLATE,
URL,
build_request_payload,
get_probs_list,
get_token_list,
send_request,
)


def test_stream_response():
Expand Down Expand Up @@ -150,7 +157,7 @@ def test_multi_turn_conversation():


def test_bad_words_filtering():
banned_tokens = [""]
banned_tokens = ["香蕉"]

data = {
"stream": False,
Expand Down Expand Up @@ -221,3 +228,50 @@ def test_bad_words_filtering1():
assert word in token_list, f"'{word}' 应出现在生成结果中"

print("test_bad_words_filtering1 正例验证通过")


def test_repetition_early_stop():
    """
    Check that the repetition early-stop feature takes effect.

    The prompt is written to provoke repetitive output. The test then inspects
    the per-token probabilities of the reply: the last ``window_size`` values
    must all exceed ``threshold``, and no earlier run of ``window_size``
    consecutive values may do so.

    threshold = 0.93
    window_size = 6; per the original note this has to be passed when the
    model service is started, otherwise the feature cannot be used.
    """
    threshold = 0.93
    window_size = 6

    # NOTE(review): no logprobs flag is set here; presumably TEMPLATE enables
    # it, otherwise get_probs_list returns [] -- confirm.
    request_fields = {
        "stream": False,
        "messages": [
            {"role": "user", "content": "输出'我爱吃果冻' 10次"},
        ],
        "max_tokens": 10000,
        "temperature": 0.8,
        "top_p": 0,
    }

    response = send_request(URL, build_request_payload(TEMPLATE, request_fields)).json()
    content = response["choices"][0]["message"]["content"]

    print("🧪 repetition early stop 输出内容:\n", content)
    probs_list = get_probs_list(response)

    assert len(probs_list) >= window_size, "列表长度不足 window_size"

    # Condition 1: every one of the trailing window_size values is above the threshold.
    assert all(prob > threshold for prob in probs_list[-window_size:]), "末尾 window_size 个数不全大于阈值"

    # Condition 2: before that tail, no run of window_size consecutive values
    # is above the threshold.
    count = 0
    for prob in probs_list[:-window_size]:
        # Extend the current run on a hit, reset it on a miss.
        count = count + 1 if prob > threshold else 0
        assert count < window_size, f"在末尾之前出现了连续 {count} 个大于阈值的数"

    print("repetition early stop 功能验证通过")
Loading