Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 41 additions & 20 deletions dashscope/aigc/multimodal_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ class Models:
def call(
cls,
model: str,
messages: List,
messages: List = None,
api_key: str = None,
workspace: str = None,
text: str = None,
**kwargs
) -> Union[MultiModalConversationResponse, Generator[
MultiModalConversationResponse, None, None]]:
Expand Down Expand Up @@ -55,6 +56,7 @@ def call(
if None, will retrieve by rule [1].
[1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501
workspace (str): The dashscope workspace id.
text (str): The input text to synthesize speech from (used by qwen-tts style models).
**kwargs:
stream(bool, `optional`): Enable server-sent events
(ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501
Expand All @@ -68,8 +70,11 @@ def call(
tokens with top_p probability mass. So 0.1 means only
the tokens comprising the top 10% probability mass are
considered[qwen-turbo,bailian-v1].
voice(string, `optional`): The voice name for qwen-tts, such as 'Cherry'/'Ethan'/'Sunny'/'Dylan'.
    The full list of available voices is documented at: https://help.aliyun.com/zh/model-studio/qwen-tts.
top_k(float, `optional`):


Raises:
InvalidInput: The history and auto_history are mutually exclusive.

Expand All @@ -78,18 +83,24 @@ def call(
Generator[MultiModalConversationResponse, None, None]]: If
stream is True, return Generator, otherwise MultiModalConversationResponse.
"""
if (messages is None or not messages):
raise InputRequired('prompt or messages is required!')
if model is None or not model:
raise ModelRequired('Model is required!')
task_group, _ = _get_task_group_and_task(__name__)
msg_copy = copy.deepcopy(messages)
has_upload = cls._preprocess_messages(model, msg_copy, api_key)
if has_upload:
headers = kwargs.pop('headers', {})
headers['X-DashScope-OssResourceResolve'] = 'enable'
kwargs['headers'] = headers
input = {'messages': msg_copy}
input = {}
msg_copy = None

if messages is not None and messages:
msg_copy = copy.deepcopy(messages)
has_upload = cls._preprocess_messages(model, msg_copy, api_key)
if has_upload:
headers = kwargs.pop('headers', {})
headers['X-DashScope-OssResourceResolve'] = 'enable'
kwargs['headers'] = headers

if text is not None and text:
input.update({'text': text})
if msg_copy is not None:
input.update({'messages': msg_copy})
response = super().call(model=model,
task_group=task_group,
task=MultiModalConversation.task,
Expand Down Expand Up @@ -145,9 +156,10 @@ class Models:
async def call(
cls,
model: str,
messages: List,
messages: List = None,
api_key: str = None,
workspace: str = None,
text: str = None,
**kwargs
) -> Union[MultiModalConversationResponse, Generator[
MultiModalConversationResponse, None, None]]:
Expand Down Expand Up @@ -176,6 +188,7 @@ async def call(
if None, will retrieve by rule [1].
[1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501
workspace (str): The dashscope workspace id.
text (str): The input text to synthesize speech from (used by qwen-tts style models).
**kwargs:
stream(bool, `optional`): Enable server-sent events
(ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501
Expand All @@ -189,6 +202,8 @@ async def call(
tokens with top_p probability mass. So 0.1 means only
the tokens comprising the top 10% probability mass are
considered[qwen-turbo,bailian-v1].
voice(string, `optional`): The voice name for qwen-tts, such as 'Cherry'/'Ethan'/'Sunny'/'Dylan'.
    The full list of available voices is documented at: https://help.aliyun.com/zh/model-studio/qwen-tts.
top_k(float, `optional`):

Raises:
Expand All @@ -199,18 +214,24 @@ async def call(
Generator[MultiModalConversationResponse, None, None]]: If
stream is True, return Generator, otherwise MultiModalConversationResponse.
"""
if (messages is None or not messages):
raise InputRequired('prompt or messages is required!')
if model is None or not model:
raise ModelRequired('Model is required!')
task_group, _ = _get_task_group_and_task(__name__)
msg_copy = copy.deepcopy(messages)
has_upload = cls._preprocess_messages(model, msg_copy, api_key)
if has_upload:
headers = kwargs.pop('headers', {})
headers['X-DashScope-OssResourceResolve'] = 'enable'
kwargs['headers'] = headers
input = {'messages': msg_copy}
input = {}
msg_copy = None

if messages is not None and messages:
msg_copy = copy.deepcopy(messages)
has_upload = cls._preprocess_messages(model, msg_copy, api_key)
if has_upload:
headers = kwargs.pop('headers', {})
headers['X-DashScope-OssResourceResolve'] = 'enable'
kwargs['headers'] = headers

if text is not None and text:
input.update({'text': text})
if msg_copy is not None:
input.update({'messages': msg_copy})
response = await super().call(model=model,
task_group=task_group,
task=AioMultiModalConversation.task,
Expand Down
34 changes: 31 additions & 3 deletions dashscope/api_entities/dashscope_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,26 @@ def __init__(self,
**kwargs)


@dataclass(init=False)
class Audio(DictMixin):
data: str
url: str
id: str
expires_at: int

def __init__(self,
data: str = None,
url: str = None,
id: str = None,
expires_at: int = None,
**kwargs):
super().__init__(data=data,
url=url,
id=id,
expires_at=expires_at,
**kwargs)


@dataclass(init=False)
class GenerationOutput(DictMixin):
text: str
Expand Down Expand Up @@ -217,36 +237,44 @@ def from_api_response(api_response: DashScopeAPIResponse):
@dataclass(init=False)
class MultiModalConversationOutput(DictMixin):
choices: List[Choice]
audio: Audio

def __init__(self,
text: str = None,
finish_reason: str = None,
choices: List[Choice] = None,
audio: Audio = None,
**kwargs):
chs = None
if choices is not None:
chs = []
for choice in choices:
chs.append(Choice(**choice))
if audio is not None:
audio = Audio(**audio)
super().__init__(text=text,
finish_reason=finish_reason,
choices=chs,
audio=audio,
**kwargs)


@dataclass(init=False)
class MultiModalConversationUsage(DictMixin):
input_tokens: int
output_tokens: int
characters: int

# TODO add image usage info.

def __init__(self,
input_tokens: int = 0,
output_tokens: int = 0,
characters: int = 0,
**kwargs):
super().__init__(input_tokens=input_tokens,
output_tokens=output_tokens,
characters=characters,
**kwargs)


Expand Down Expand Up @@ -378,7 +406,7 @@ def is_sentence_end(sentence: Dict[str, Any]) -> bool:
"""
result = False
if sentence is not None and 'end_time' in sentence and sentence[
'end_time'] is not None:
'end_time'] is not None:
result = True
return result

Expand Down Expand Up @@ -445,8 +473,8 @@ class ImageSynthesisOutput(DictMixin):
results: List[ImageSynthesisResult]

def __init__(self,
task_id: str = None,
task_status: str = None,
task_id: str = None,
task_status: str = None,
results: List[ImageSynthesisResult] = [],
**kwargs):
res = []
Expand Down
38 changes: 38 additions & 0 deletions samples/test_qwen_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

import dashscope
import logging

logger = logging.getLogger('dashscope')
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
# create formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# add formatter to ch
console_handler.setFormatter(formatter)

# add ch to logger
logger.addHandler(console_handler)

# switch stream or non-stream mode
use_stream = True

response = dashscope.MultiModalConversation.call(
api_key=os.getenv('DASHSCOPE_API_KEY'),
model="qwen-tts",
text="Today is a wonderful day to build something people love!",
voice="Cherry",
stream=use_stream
)
if use_stream:
# print the audio data in stream mode
for chunk in response:
audio = chunk.output.audio
print("base64 audio data is: {}", chunk.output.audio.data)
if chunk.output.finish_reason == "stop":
print("finish at: {} ", chunk.output.audio.expires_at)
else:
# print the audio url in non-stream mode
print("synthesized audio url is: {}", response.output.audio.url)
print("finish at: {} ", response.output.audio.expires_at)