Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions samples/test_text_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import sys

# Add the project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dashscope import TextReRank


def test_text_rerank():
"""Test text rerank API with instruct parameter."""
query = "哈尔滨在哪?"
documents = [
"黑龙江离俄罗斯很近",
"哈尔滨是中国黑龙江省的省会,位于中国东北"
]

try:
response = TextReRank.call(
model=os.getenv("MODEL_NAME"),
query=query,
documents=documents,
return_documents=True,
top_n=5,
instruct="Retrieval document that can answer users query."
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There appears to be a grammatical error and a typo in the instruct string. "Retrieval" is a noun, but an instruction should typically start with a verb like "Retrieve". Also, "document" should probably be plural "documents", and "users" should be possessive "user's". Correcting this will improve clarity and may lead to better model performance.

Suggested change
instruct="Retrieval document that can answer users query."
instruct="Retrieve documents that can answer the user's query."

)

print(f'response: {response}')

print("\n✅ Test passed! All assertions successful.")
Comment on lines +31 to +32
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The function is named test_text_rerank and the success message claims "All assertions successful", but there are no assertions in the code. This makes it an example script rather than a test. To ensure the API is working correctly, you should add assertions to validate the structure and content of the response.

Suggested change
print("\n✅ Test passed! All assertions successful.")
assert response.output, "Response should have an output."
assert 'results' in response.output, "Output should contain 'results'."
results = response.output['results']
assert isinstance(results, list)
assert len(results) <= len(documents)
assert all('relevance_score' in r for r in results)
assert all('index' in r for r in results)
assert all('document' in r for r in results), "document should be returned when return_documents=True"
print("\n✅ Test passed! All assertions successful.")


except Exception as e:
print(f"❌ Test failed with error: {str(e)}")
raise

if __name__ == "__main__":
# Load environment variables if .env file exists
try:
with open(os.path.expanduser('~/.env'), 'r') as f:
for line in f:
if line.strip() and not line.startswith('#'):
key, value = line.strip().split('=', 1)
os.environ[key] = value
except FileNotFoundError:
print("No .env file found, using system environment variables")
Comment on lines +40 to +47
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation loads environment variables from ~/.env, which is a user-level global configuration file. For project-specific settings, it's more conventional to use a .env file in the project's root directory. This makes the project more self-contained and easier for other developers to set up. I suggest modifying this logic to load from the project root and also improving the error message to be more specific.

Suggested change
try:
with open(os.path.expanduser('~/.env'), 'r') as f:
for line in f:
if line.strip() and not line.startswith('#'):
key, value = line.strip().split('=', 1)
os.environ[key] = value
except FileNotFoundError:
print("No .env file found, using system environment variables")
try:
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
dotenv_path = os.path.join(project_root, '.env')
with open(dotenv_path, 'r') as f:
for line in f:
if line.strip() and not line.startswith('#'):
key, value = line.strip().split('=', 1)
os.environ[key] = value
except FileNotFoundError:
print(f"No .env file found at '{dotenv_path}', using system environment variables")


# Run tests
test_text_rerank()

print("\n🎉 All tests completed successfully!")