-
Notifications
You must be signed in to change notification settings - Fork 20
feat: support advanced non-incremental output #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
510fb7f
98b3e4a
7163801
29781fd
ee9377d
79ab9cf
1ee1a61
1440118
7cb7958
ee1dc4e
a0ebbc3
d70b2c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,14 +1,16 @@ | ||||||||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | ||||||||||
|
|
||||||||||
| import copy | ||||||||||
| from typing import Generator, List, Union | ||||||||||
| from typing import AsyncGenerator, Generator, List, Union | ||||||||||
|
|
||||||||||
| from dashscope.api_entities.dashscope_response import \ | ||||||||||
| MultiModalConversationResponse | ||||||||||
| from dashscope.client.base_api import BaseAioApi, BaseApi | ||||||||||
| from dashscope.common.error import InputRequired, ModelRequired | ||||||||||
| from dashscope.common.utils import _get_task_group_and_task | ||||||||||
| from dashscope.utils.oss_utils import preprocess_message_element | ||||||||||
| from dashscope.utils.param_utils import ParamUtil | ||||||||||
| from dashscope.utils.message_utils import merge_multimodal_single_response | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class MultiModalConversation(BaseApi): | ||||||||||
|
|
@@ -108,6 +110,16 @@ def call( | |||||||||
| input.update({'language_type': language_type}) | ||||||||||
| if msg_copy is not None: | ||||||||||
| input.update({'messages': msg_copy}) | ||||||||||
|
|
||||||||||
| # Check if we need to merge incremental output | ||||||||||
| is_incremental_output = kwargs.get('incremental_output', None) | ||||||||||
| to_merge_incremental_output = False | ||||||||||
| is_stream = kwargs.get('stream', False) | ||||||||||
| if (ParamUtil.should_modify_incremental_output(model) and | ||||||||||
| is_stream and is_incremental_output is not None and is_incremental_output is False): | ||||||||||
|
Comment on lines
+118
to
+119
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The condition `is_incremental_output is not None and is_incremental_output is False` is redundant: `is_incremental_output is False` already implies it is not `None`, so the `is not None` check can be dropped.
Suggested change
|
||||||||||
| to_merge_incremental_output = True | ||||||||||
| kwargs['incremental_output'] = True | ||||||||||
|
|
||||||||||
| response = super().call(model=model, | ||||||||||
| task_group=task_group, | ||||||||||
| task=MultiModalConversation.task, | ||||||||||
|
|
@@ -116,10 +128,14 @@ def call( | |||||||||
| input=input, | ||||||||||
| workspace=workspace, | ||||||||||
| **kwargs) | ||||||||||
| is_stream = kwargs.get('stream', False) | ||||||||||
| if is_stream: | ||||||||||
| return (MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
| for rsp in response) | ||||||||||
| if to_merge_incremental_output: | ||||||||||
| # Extract n parameter for merge logic | ||||||||||
| n = kwargs.get('n', 1) | ||||||||||
| return cls._merge_multimodal_response(response, n) | ||||||||||
| else: | ||||||||||
| return (MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
| for rsp in response) | ||||||||||
| else: | ||||||||||
| return MultiModalConversationResponse.from_api_response(response) | ||||||||||
|
|
||||||||||
|
|
@@ -149,6 +165,21 @@ def _preprocess_messages(cls, model: str, messages: List[dict], | |||||||||
| has_upload = True | ||||||||||
| return has_upload | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def _merge_multimodal_response(cls, response, n=1) -> Generator[MultiModalConversationResponse, None, None]: | ||||||||||
| """Merge incremental response chunks to simulate non-incremental output.""" | ||||||||||
| accumulated_data = {} | ||||||||||
|
|
||||||||||
| for rsp in response: | ||||||||||
| parsed_response = MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
| result = merge_multimodal_single_response(parsed_response, accumulated_data, n) | ||||||||||
| if result is True: | ||||||||||
| yield parsed_response | ||||||||||
| elif isinstance(result, list): | ||||||||||
| # Multiple responses to yield (for n>1 non-stop cases) | ||||||||||
| for resp in result: | ||||||||||
| yield resp | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class AioMultiModalConversation(BaseAioApi): | ||||||||||
| """Async MultiModal conversational robot interface. | ||||||||||
|
|
@@ -170,8 +201,8 @@ async def call( | |||||||||
| voice: str = None, | ||||||||||
| language_type: str = None, | ||||||||||
| **kwargs | ||||||||||
| ) -> Union[MultiModalConversationResponse, Generator[ | ||||||||||
| MultiModalConversationResponse, None, None]]: | ||||||||||
| ) -> Union[MultiModalConversationResponse, AsyncGenerator[ | ||||||||||
| MultiModalConversationResponse, None]]: | ||||||||||
| """Call the conversation model service asynchronously. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
|
|
@@ -221,8 +252,8 @@ async def call( | |||||||||
|
|
||||||||||
| Returns: | ||||||||||
| Union[MultiModalConversationResponse, | ||||||||||
| Generator[MultiModalConversationResponse, None, None]]: If | ||||||||||
| stream is True, return Generator, otherwise MultiModalConversationResponse. | ||||||||||
| AsyncGenerator[MultiModalConversationResponse, None]]: If | ||||||||||
| stream is True, return AsyncGenerator, otherwise MultiModalConversationResponse. | ||||||||||
| """ | ||||||||||
| if model is None or not model: | ||||||||||
| raise ModelRequired('Model is required!') | ||||||||||
|
|
@@ -246,6 +277,16 @@ async def call( | |||||||||
| input.update({'language_type': language_type}) | ||||||||||
| if msg_copy is not None: | ||||||||||
| input.update({'messages': msg_copy}) | ||||||||||
|
|
||||||||||
| # Check if we need to merge incremental output | ||||||||||
| is_incremental_output = kwargs.get('incremental_output', None) | ||||||||||
| to_merge_incremental_output = False | ||||||||||
| is_stream = kwargs.get('stream', False) | ||||||||||
| if (ParamUtil.should_modify_incremental_output(model) and | ||||||||||
| is_stream and is_incremental_output is not None and is_incremental_output is False): | ||||||||||
| to_merge_incremental_output = True | ||||||||||
| kwargs['incremental_output'] = True | ||||||||||
|
|
||||||||||
| response = await super().call(model=model, | ||||||||||
| task_group=task_group, | ||||||||||
| task=AioMultiModalConversation.task, | ||||||||||
|
|
@@ -254,10 +295,13 @@ async def call( | |||||||||
| input=input, | ||||||||||
| workspace=workspace, | ||||||||||
| **kwargs) | ||||||||||
| is_stream = kwargs.get('stream', False) | ||||||||||
| if is_stream: | ||||||||||
| return (MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
| async for rsp in response) | ||||||||||
| if to_merge_incremental_output: | ||||||||||
| # Extract n parameter for merge logic | ||||||||||
| n = kwargs.get('n', 1) | ||||||||||
| return cls._merge_multimodal_response(response, n) | ||||||||||
| else: | ||||||||||
| return cls._stream_responses(response) | ||||||||||
| else: | ||||||||||
| return MultiModalConversationResponse.from_api_response(response) | ||||||||||
|
|
||||||||||
|
|
@@ -286,3 +330,27 @@ def _preprocess_messages(cls, model: str, messages: List[dict], | |||||||||
| if is_upload and not has_upload: | ||||||||||
| has_upload = True | ||||||||||
| return has_upload | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| async def _stream_responses(cls, response) -> AsyncGenerator[MultiModalConversationResponse, None]: | ||||||||||
| """Convert async response stream to MultiModalConversationResponse stream.""" | ||||||||||
| # Type hint: when stream=True, response is actually an AsyncIterable | ||||||||||
| async for rsp in response: # type: ignore | ||||||||||
| yield MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| async def _merge_multimodal_response(cls, response, n=1) -> AsyncGenerator[MultiModalConversationResponse, None]: | ||||||||||
| """Async version of merge incremental response chunks.""" | ||||||||||
| accumulated_data = {} | ||||||||||
|
|
||||||||||
| async for rsp in response: | ||||||||||
| parsed_response = MultiModalConversationResponse.from_api_response(rsp) | ||||||||||
| result = merge_multimodal_single_response(parsed_response, accumulated_data, n) | ||||||||||
| if result is True: | ||||||||||
| yield parsed_response | ||||||||||
| elif isinstance(result, list): | ||||||||||
| # Multiple responses to yield (for n>1 non-stop cases) | ||||||||||
| for resp in result: | ||||||||||
| yield resp | ||||||||||
|
|
||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code to determine if response merging is needed is identical to the logic in the synchronous
`Generation.call` method (lines 142-149). To improve maintainability and reduce code duplication, consider extracting this logic into a shared private helper function that both `call` and `acall` can use.