Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dashscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from dashscope.files import Files
from dashscope.models import Models
from dashscope.nlp.understanding import Understanding
from dashscope.rerank.text_rerank import TextReRank
from dashscope.rerank.text_rerank import AioTextReRank, TextReRank
from dashscope.threads import (
MessageFile,
Messages,
Expand Down Expand Up @@ -100,6 +100,7 @@
"list_tokenizers",
"Application",
"TextReRank",
"AioTextReRank",
"Assistants",
"Threads",
"Messages",
Expand Down
69 changes: 68 additions & 1 deletion dashscope/rerank/text_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List

from dashscope.api_entities.dashscope_response import ReRankResponse
from dashscope.client.base_api import BaseApi
from dashscope.client.base_api import BaseApi, BaseAioApi
from dashscope.common.error import InputRequired, ModelRequired
from dashscope.common.utils import _get_task_group_and_task

Expand Down Expand Up @@ -76,3 +76,70 @@ def call( # type: ignore[override]
)

return ReRankResponse.from_api_response(response)


class AioTextReRank(BaseAioApi):
task = "text-rerank"
"""Async API for rerank models."""

class Models:
gte_rerank = "gte-rerank"

@classmethod
async def call( # type: ignore[override]
cls,
model: str,
query: str,
documents: List[str],
return_documents: bool = None,
top_n: int = None,
api_key: str = None,
**kwargs,
) -> ReRankResponse:
"""Calling rerank service asynchronously.

Args:
model (str): The model to use.
query (str): The query string.
documents (List[str]): The documents to rank.
return_documents(bool, `optional`): enable return origin documents,
system default is false.
top_n(int, `optional`): how many documents to return, default return # noqa: E501
all the documents.
api_key (str, optional): The DashScope api key. Defaults to None.

Raises:
InputRequired: The query and documents are required.
ModelRequired: The model is required.

Returns:
RerankResponse: The rerank result.
"""

if query is None or documents is None or not documents:
raise InputRequired("query and documents are required!")
if model is None or not model:
raise ModelRequired("Model is required!")
task_group, function = _get_task_group_and_task(__name__)
input = { # pylint: disable=redefined-builtin
"query": query,
"documents": documents,
}
parameters = {}
if return_documents is not None:
parameters["return_documents"] = return_documents
if top_n is not None:
parameters["top_n"] = top_n
parameters = {**parameters, **kwargs}

response = await super().call(
model=model,
task_group=task_group,
task=TextReRank.task,
function=function,
api_key=api_key,
input=input,
**parameters, # type: ignore[arg-type]
)

return ReRankResponse.from_api_response(response)
28 changes: 27 additions & 1 deletion samples/test_text_rerank.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import asyncio
import os

from dashscope import TextReRank
from dashscope import AioTextReRank, TextReRank


def test_text_rerank():
Expand All @@ -30,5 +31,30 @@ def test_text_rerank():
raise


async def test_aio_text_rerank():
"""Test async text rerank API with instruct parameter."""
query = "哈尔滨在哪?"
documents = [
"黑龙江离俄罗斯很近",
"哈尔滨是中国黑龙江省的省会,位于中国东北",
]

try:
response = await AioTextReRank.call(
model=os.getenv("MODEL_NAME"),
query=query,
documents=documents,
return_documents=True,
top_n=5,
instruct="Retrieval document that can answer users query.",
)

print(f"response:\n{response}")

except Exception as e:
raise


if __name__ == "__main__":
test_text_rerank()
asyncio.run(test_aio_text_rerank())