Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 75 additions & 10 deletions dashscope/aigc/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy
import json
from typing import Any, Dict, Generator, List, Union
from typing import Any, Dict, Generator, List, Union, AsyncGenerator

from dashscope.api_entities.dashscope_response import (GenerationResponse,
Message, Role)
Expand All @@ -13,6 +13,8 @@
from dashscope.common.error import InputRequired, ModelRequired
from dashscope.common.logging import logger
from dashscope.common.utils import _get_task_group_and_task
from dashscope.utils.param_utils import ParamUtil
from dashscope.utils.message_utils import merge_single_response


class Generation(BaseApi):
Expand Down Expand Up @@ -137,6 +139,16 @@ def call(
kwargs['headers'] = headers
input, parameters = cls._build_input_parameters(
model, prompt, history, messages, **kwargs)

is_stream = parameters.get('stream', False)
# Check if we need to merge incremental output
is_incremental_output = kwargs.get('incremental_output', None)
to_merge_incremental_output = False
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is False):
to_merge_incremental_output = True
parameters['incremental_output'] = True

response = super().call(model=model,
task_group=task_group,
task=Generation.task,
Expand All @@ -145,10 +157,14 @@ def call(
input=input,
workspace=workspace,
**parameters)
is_stream = kwargs.get('stream', False)
if is_stream:
return (GenerationResponse.from_api_response(rsp)
for rsp in response)
if to_merge_incremental_output:
# Extract n parameter for merge logic
n = parameters.get('n', 1)
return cls._merge_generation_response(response, n)
else:
return (GenerationResponse.from_api_response(rsp)
for rsp in response)
else:
return GenerationResponse.from_api_response(response)

Expand Down Expand Up @@ -191,6 +207,20 @@ def _build_input_parameters(cls, model, prompt, history, messages,

return input, {**parameters, **kwargs}

@classmethod
def _merge_generation_response(
    cls, response, n=1
) -> Generator[GenerationResponse, None, None]:
    """Re-emit a streamed response after passing it through the merge helper.

    Every raw chunk is parsed into a ``GenerationResponse`` and handed to
    ``merge_single_response`` together with one dict that is shared by all
    chunks of this stream, plus ``n``.

    Args:
        response: Iterable of raw API responses (the streaming result).
        n: Forwarded unchanged to ``merge_single_response``.

    Yields:
        The parsed chunk when the helper returns ``True``, or each element
        of the helper's return value when that value is a list. Any other
        return value produces no output for that chunk.
    """
    # One state dict per stream; the helper receives it on every call.
    merge_state = {}
    for raw_chunk in response:
        chunk = GenerationResponse.from_api_response(raw_chunk)
        outcome = merge_single_response(chunk, merge_state, n)
        if outcome is True:
            # NOTE(review): presumably the helper updated ``chunk`` in
            # place with the accumulated content - confirm in message_utils.
            yield chunk
            continue
        if isinstance(outcome, list):
            # The helper supplied several responses for this one chunk.
            for merged in outcome:
                yield merged


class AioGeneration(BaseAioApi):
task = 'text-generation'
Expand Down Expand Up @@ -220,7 +250,7 @@ async def call(
plugins: Union[str, Dict[str, Any]] = None,
workspace: str = None,
**kwargs
) -> Union[GenerationResponse, Generator[GenerationResponse, None, None]]:
) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
"""Call generation model service.

Args:
Expand Down Expand Up @@ -296,8 +326,8 @@ async def call(

Returns:
Union[GenerationResponse,
Generator[GenerationResponse, None, None]]: If
stream is True, return Generator, otherwise GenerationResponse.
AsyncGenerator[GenerationResponse, None]]: If
stream is True, return AsyncGenerator, otherwise GenerationResponse.
"""
if (prompt is None or not prompt) and (messages is None
or not messages):
Expand All @@ -314,6 +344,16 @@ async def call(
kwargs['headers'] = headers
input, parameters = Generation._build_input_parameters(
model, prompt, history, messages, **kwargs)

is_stream = parameters.get('stream', False)
# Check if we need to merge incremental output
is_incremental_output = kwargs.get('incremental_output', None)
to_merge_incremental_output = False
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is False):
to_merge_incremental_output = True
parameters['incremental_output'] = True
Comment on lines +348 to +355
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block, which determines whether response merging is needed, is identical to the logic in the synchronous Generation.call method (lines 142-149). To improve maintainability and reduce code duplication, consider extracting it into a shared private helper that both Generation.call and AioGeneration.call can use.


response = await super().call(model=model,
task_group=task_group,
task=Generation.task,
Expand All @@ -322,9 +362,34 @@ async def call(
input=input,
workspace=workspace,
**parameters)
is_stream = kwargs.get('stream', False)
if is_stream:
return (GenerationResponse.from_api_response(rsp)
async for rsp in response)
if to_merge_incremental_output:
# Extract n parameter for merge logic
n = parameters.get('n', 1)
return cls._merge_generation_response(response, n)
else:
return cls._stream_responses(response)
else:
return GenerationResponse.from_api_response(response)

@classmethod
async def _stream_responses(
    cls, response
) -> AsyncGenerator[GenerationResponse, None]:
    """Wrap each item of an async response stream as a GenerationResponse.

    Args:
        response: Async-iterable of raw API responses.

    Yields:
        ``GenerationResponse.from_api_response(item)`` for every item, in
        arrival order.
    """
    # The declared type of ``response`` is not an AsyncIterable, hence the
    # suppression on the loop below.
    async for raw_chunk in response:  # type: ignore
        yield GenerationResponse.from_api_response(raw_chunk)

@classmethod
async def _merge_generation_response(
    cls, response, n=1
) -> AsyncGenerator[GenerationResponse, None]:
    """Async variant: re-emit a streamed response through the merge helper.

    Every raw chunk is parsed into a ``GenerationResponse`` and handed to
    ``merge_single_response`` together with one dict that is shared by all
    chunks of this stream, plus ``n``.

    Args:
        response: Async-iterable of raw API responses.
        n: Forwarded unchanged to ``merge_single_response``.

    Yields:
        The parsed chunk when the helper returns ``True``, or each element
        of the helper's return value when that value is a list. Any other
        return value produces no output for that chunk.
    """
    # One state dict per stream; the helper receives it on every call.
    merge_state = {}
    async for raw_chunk in response:  # type: ignore
        chunk = GenerationResponse.from_api_response(raw_chunk)
        outcome = merge_single_response(chunk, merge_state, n)
        if outcome is True:
            yield chunk
            continue
        if isinstance(outcome, list):
            # The helper supplied several responses for this one chunk.
            for merged in outcome:
                yield merged
90 changes: 79 additions & 11 deletions dashscope/aigc/multimodal_conversation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
from typing import Generator, List, Union
from typing import AsyncGenerator, Generator, List, Union

from dashscope.api_entities.dashscope_response import \
MultiModalConversationResponse
from dashscope.client.base_api import BaseAioApi, BaseApi
from dashscope.common.error import InputRequired, ModelRequired
from dashscope.common.utils import _get_task_group_and_task
from dashscope.utils.oss_utils import preprocess_message_element
from dashscope.utils.param_utils import ParamUtil
from dashscope.utils.message_utils import merge_multimodal_single_response


class MultiModalConversation(BaseApi):
Expand Down Expand Up @@ -108,6 +110,16 @@ def call(
input.update({'language_type': language_type})
if msg_copy is not None:
input.update({'messages': msg_copy})

# Check if we need to merge incremental output
is_incremental_output = kwargs.get('incremental_output', None)
to_merge_incremental_output = False
is_stream = kwargs.get('stream', False)
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is not None and is_incremental_output is False):
Comment on lines +118 to +119
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition `is_incremental_output is not None and is_incremental_output is False` is unnecessarily verbose: the `is not None` check is redundant, because a value that `is False` can never be `None`. The condition can be simplified as suggested below. The same redundant check appears in AioMultiModalConversation.call on lines 277-278.

Suggested change
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is not None and is_incremental_output is False):
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is False):

to_merge_incremental_output = True
kwargs['incremental_output'] = True

response = super().call(model=model,
task_group=task_group,
task=MultiModalConversation.task,
Expand All @@ -116,10 +128,14 @@ def call(
input=input,
workspace=workspace,
**kwargs)
is_stream = kwargs.get('stream', False)
if is_stream:
return (MultiModalConversationResponse.from_api_response(rsp)
for rsp in response)
if to_merge_incremental_output:
# Extract n parameter for merge logic
n = kwargs.get('n', 1)
return cls._merge_multimodal_response(response, n)
else:
return (MultiModalConversationResponse.from_api_response(rsp)
for rsp in response)
else:
return MultiModalConversationResponse.from_api_response(response)

Expand Down Expand Up @@ -149,6 +165,21 @@ def _preprocess_messages(cls, model: str, messages: List[dict],
has_upload = True
return has_upload

@classmethod
def _merge_multimodal_response(
    cls, response, n=1
) -> Generator[MultiModalConversationResponse, None, None]:
    """Re-emit a streamed multimodal response through the merge helper.

    Every raw chunk is parsed into a ``MultiModalConversationResponse``
    and handed to ``merge_multimodal_single_response`` together with one
    dict that is shared by all chunks of this stream, plus ``n``.

    Args:
        response: Iterable of raw API responses (the streaming result).
        n: Forwarded unchanged to ``merge_multimodal_single_response``.

    Yields:
        The parsed chunk when the helper returns ``True``, or each element
        of the helper's return value when that value is a list. Any other
        return value produces no output for that chunk.
    """
    # One state dict per stream; the helper receives it on every call.
    merge_state = {}
    for raw_chunk in response:
        chunk = MultiModalConversationResponse.from_api_response(raw_chunk)
        outcome = merge_multimodal_single_response(chunk, merge_state, n)
        if outcome is True:
            yield chunk
            continue
        if isinstance(outcome, list):
            # The helper supplied several responses for this one chunk.
            for merged in outcome:
                yield merged


class AioMultiModalConversation(BaseAioApi):
"""Async MultiModal conversational robot interface.
Expand All @@ -170,8 +201,8 @@ async def call(
voice: str = None,
language_type: str = None,
**kwargs
) -> Union[MultiModalConversationResponse, Generator[
MultiModalConversationResponse, None, None]]:
) -> Union[MultiModalConversationResponse, AsyncGenerator[
MultiModalConversationResponse, None]]:
"""Call the conversation model service asynchronously.

Args:
Expand Down Expand Up @@ -221,8 +252,8 @@ async def call(

Returns:
Union[MultiModalConversationResponse,
Generator[MultiModalConversationResponse, None, None]]: If
stream is True, return Generator, otherwise MultiModalConversationResponse.
AsyncGenerator[MultiModalConversationResponse, None]]: If
stream is True, return AsyncGenerator, otherwise MultiModalConversationResponse.
"""
if model is None or not model:
raise ModelRequired('Model is required!')
Expand All @@ -246,6 +277,16 @@ async def call(
input.update({'language_type': language_type})
if msg_copy is not None:
input.update({'messages': msg_copy})

# Check if we need to merge incremental output
is_incremental_output = kwargs.get('incremental_output', None)
to_merge_incremental_output = False
is_stream = kwargs.get('stream', False)
if (ParamUtil.should_modify_incremental_output(model) and
is_stream and is_incremental_output is not None and is_incremental_output is False):
to_merge_incremental_output = True
kwargs['incremental_output'] = True

response = await super().call(model=model,
task_group=task_group,
task=AioMultiModalConversation.task,
Expand All @@ -254,10 +295,13 @@ async def call(
input=input,
workspace=workspace,
**kwargs)
is_stream = kwargs.get('stream', False)
if is_stream:
return (MultiModalConversationResponse.from_api_response(rsp)
async for rsp in response)
if to_merge_incremental_output:
# Extract n parameter for merge logic
n = kwargs.get('n', 1)
return cls._merge_multimodal_response(response, n)
else:
return cls._stream_responses(response)
else:
return MultiModalConversationResponse.from_api_response(response)

Expand Down Expand Up @@ -286,3 +330,27 @@ def _preprocess_messages(cls, model: str, messages: List[dict],
if is_upload and not has_upload:
has_upload = True
return has_upload

@classmethod
async def _stream_responses(
    cls, response
) -> AsyncGenerator[MultiModalConversationResponse, None]:
    """Wrap each item of an async stream as a MultiModalConversationResponse.

    Args:
        response: Async-iterable of raw API responses.

    Yields:
        ``MultiModalConversationResponse.from_api_response(item)`` for
        every item, in arrival order.
    """
    # The declared type of ``response`` is not an AsyncIterable, hence the
    # suppression on the loop below.
    async for raw_chunk in response:  # type: ignore
        yield MultiModalConversationResponse.from_api_response(raw_chunk)

@classmethod
async def _merge_multimodal_response(
    cls, response, n=1
) -> AsyncGenerator[MultiModalConversationResponse, None]:
    """Async variant: re-emit a multimodal stream through the merge helper.

    Every raw chunk is parsed into a ``MultiModalConversationResponse``
    and handed to ``merge_multimodal_single_response`` together with one
    dict that is shared by all chunks of this stream, plus ``n``.

    Args:
        response: Async-iterable of raw API responses.
        n: Forwarded unchanged to ``merge_multimodal_single_response``.

    Yields:
        The parsed chunk when the helper returns ``True``, or each element
        of the helper's return value when that value is a list. Any other
        return value produces no output for that chunk.
    """
    # One state dict per stream; the helper receives it on every call.
    merge_state = {}
    async for raw_chunk in response:
        chunk = MultiModalConversationResponse.from_api_response(raw_chunk)
        outcome = merge_multimodal_single_response(chunk, merge_state, n)
        if outcome is True:
            yield chunk
            continue
        if isinstance(outcome, list):
            # The helper supplied several responses for this one chunk.
            for merged in outcome:
                yield merged


Loading