From 37f4412d43c8ecce415b287a6d9e8e5874719016 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 25 Mar 2026 15:11:26 +0000 Subject: [PATCH] add AioTextReRank to support async rerank --- dashscope/__init__.py | 3 +- dashscope/rerank/text_rerank.py | 69 ++++++++++++++++++++++++++++++++- samples/test_text_rerank.py | 28 ++++++++++++- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/dashscope/__init__.py b/dashscope/__init__.py index 744b269..2d0bbc9 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -42,7 +42,7 @@ from dashscope.files import Files from dashscope.models import Models from dashscope.nlp.understanding import Understanding -from dashscope.rerank.text_rerank import TextReRank +from dashscope.rerank.text_rerank import AioTextReRank, TextReRank from dashscope.threads import ( MessageFile, Messages, @@ -100,6 +100,7 @@ "list_tokenizers", "Application", "TextReRank", + "AioTextReRank", "Assistants", "Threads", "Messages", diff --git a/dashscope/rerank/text_rerank.py b/dashscope/rerank/text_rerank.py index c38bf1d..738c436 100644 --- a/dashscope/rerank/text_rerank.py +++ b/dashscope/rerank/text_rerank.py @@ -4,7 +4,7 @@ from typing import List from dashscope.api_entities.dashscope_response import ReRankResponse -from dashscope.client.base_api import BaseApi +from dashscope.client.base_api import BaseApi, BaseAioApi from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.utils import _get_task_group_and_task @@ -76,3 +76,70 @@ def call( # type: ignore[override] ) return ReRankResponse.from_api_response(response) + + +class AioTextReRank(BaseAioApi): + task = "text-rerank" + """Async API for rerank models.""" + + class Models: + gte_rerank = "gte-rerank" + + @classmethod + async def call( # type: ignore[override] + cls, + model: str, + query: str, + documents: List[str], + return_documents: bool = None, + top_n: int = None, + api_key: str = None, + **kwargs, + ) -> ReRankResponse: + """Calling rerank service asynchronously. + + Args: + model (str): The model to use. + query (str): The query string. + documents (List[str]): The documents to rank. + return_documents(bool, `optional`): enable return origin documents, + system default is false. + top_n(int, `optional`): how many documents to return, default return # noqa: E501 + all the documents. + api_key (str, optional): The DashScope api key. Defaults to None. + + Raises: + InputRequired: The query and documents are required. + ModelRequired: The model is required. + + Returns: + RerankResponse: The rerank result. + """ + + if query is None or documents is None or not documents: + raise InputRequired("query and documents are required!") + if model is None or not model: + raise ModelRequired("Model is required!") + task_group, function = _get_task_group_and_task(__name__) + input = { # pylint: disable=redefined-builtin + "query": query, + "documents": documents, + } + parameters = {} + if return_documents is not None: + parameters["return_documents"] = return_documents + if top_n is not None: + parameters["top_n"] = top_n + parameters = {**parameters, **kwargs} + + response = await super().call( + model=model, + task_group=task_group, + task=TextReRank.task, + function=function, + api_key=api_key, + input=input, + **parameters, # type: ignore[arg-type] + ) + + return ReRankResponse.from_api_response(response) diff --git a/samples/test_text_rerank.py b/samples/test_text_rerank.py index cc19abe..abf1dd7 100644 --- a/samples/test_text_rerank.py +++ b/samples/test_text_rerank.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio import os -from dashscope import TextReRank +from dashscope import AioTextReRank, TextReRank def test_text_rerank(): @@ -30,5 +31,30 @@ def test_text_rerank(): raise +async def test_aio_text_rerank(): + """Test async text rerank API with instruct parameter.""" + query = "哈尔滨在哪?" + documents = [ + "黑龙江离俄罗斯很近", + "哈尔滨是中国黑龙江省的省会,位于中国东北", + ] + + try: + response = await AioTextReRank.call( + model=os.getenv("MODEL_NAME"), + query=query, + documents=documents, + return_documents=True, + top_n=5, + instruct="Retrieval document that can answer users query.", + ) + + print(f"response:\n{response}") + + except Exception as e: + raise + + if __name__ == "__main__": test_text_rerank() + asyncio.run(test_aio_text_rerank())