diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26003a8fd..6020cfb85 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # Changelog
 
+## 4.21
+- [#287](https://github.com/cohere-ai/cohere-python/pull/287)
+  - Remove the deprecated chat "query" parameter, including the deprecated "text" key inside the chat_history parameter
+  - Support event types for chat streaming
+
 ## 4.20.2
 - [#284](https://github.com/cohere-ai/cohere-python/pull/284)
   - Rename dataset urls to download_urls
diff --git a/cohere/client.py b/cohere/client.py
index 1340ff3b5..a21b23e83 100644
--- a/cohere/client.py
+++ b/cohere/client.py
@@ -227,7 +227,6 @@ def generate(
     def chat(
         self,
         message: Optional[str] = None,
-        query: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
         return_chatlog: Optional[bool] = False,
@@ -246,22 +245,26 @@ def chat(
         """Returns a Chat object with the query reply.
 
         Args:
-            query (str): Deprecated. Use message instead.
             message (str): The message to send to the chatbot.
-            conversation_id (str): (Optional) The conversation id to continue the conversation.
-            model (str): (Optional) The model to use for generating the next reply.
-            return_chatlog (bool): (Optional) Whether to return the chatlog.
-            return_prompt (bool): (Optional) Whether to return the prompt.
-            return_preamble (bool): (Optional) Whether to return the preamble.
-            chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.
-            preamble_override (str): (Optional) A string to override the preamble.
-            user_name (str): (Optional) A string to override the username.
-            temperature (float): (Optional) The temperature to use for the next reply. The higher the temperature, the more random the reply.
-            max_tokens (int): (Optional) The max tokens generated for the next reply.
+            stream (bool): Whether to return streaming tokens.
+            conversation_id (str): (Optional) To store a conversation, create a conversation id and use it for every related request.
+
+            preamble_override (str): (Optional) A string to override the preamble.
+            chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored, so no data will be stored to maintain state.
+
+            model (str): (Optional) The model to use for generating the response.
+            temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response.
             p (float): (Optional) The nucleus sampling probability.
             k (float): (Optional) The top-k sampling probability.
             logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply.
+            max_tokens (int): (Optional) The max tokens generated for the next reply.
+
+            return_chatlog (bool): (Optional) Whether to return the chatlog.
+            return_prompt (bool): (Optional) Whether to return the prompt.
+            return_preamble (bool): (Optional) Whether to return the preamble.
+
+            user_name (str): (Optional) A string to override the username.
 
         Returns:
             a Chat object if stream=False, or a StreamingChat object if stream=True
@@ -269,7 +272,6 @@ def chat(
         A simple chat message:
             >>> res = co.chat(message="Hey! How are you doing today?")
             >>> print(res.text)
-            >>> print(res.conversation_id)
         Continuing a session using a specific model:
             >>> res = co.chat(
             >>>     message="Hey! How are you doing today?",
@@ -295,25 +297,6 @@ def chat(
             >>> print(res.text)
             >>> print(res.prompt)
         """
-        if chat_history is not None:
-            should_warn = True
-            for entry in chat_history:
-                if "text" in entry:
-                    entry["message"] = entry["text"]
-
-                if "text" in entry and should_warn:
-                    logger.warning(
-                        "The chat_history 'text' key is deprecated and will be removed in a future version of this function. "
-                        + "Use 'message' instead.",
-                    )
-                    should_warn = False
-
-        if query is not None:
-            logger.warning(
-                "The 'query' parameter is deprecated and will be removed in a future version of this function. "
-                + "Use 'message' instead.",
-            )
-            message = query
 
         json_body = {
             "message": message,
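For callers migrating off `query`, a minimal sketch of the new call shape. The API key is a placeholder, and the `"user_name"` key in the `chat_history` entries is shown as it appears in the Chat API docs of this era; only the `"text"` → `"message"` rename is confirmed by this diff:

```python
import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

# Pre-4.21 code such as co.chat(query="...") must now pass `message`:
res = co.chat(message="Hey! How are you doing today?")
print(res.text)

# chat_history entries must use the "message" key; the deprecated "text"
# key is no longer rewritten on your behalf:
res = co.chat(
    message="What was my last message?",
    chat_history=[
        {"user_name": "User", "message": "Hey! How are you doing today?"},
        {"user_name": "Chatbot", "message": "Doing great, thanks!"},
    ],
)
```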
diff --git a/cohere/client_async.py b/cohere/client_async.py
index 4ca5ecc6c..d64657391 100644
--- a/cohere/client_async.py
+++ b/cohere/client_async.py
@@ -209,7 +209,6 @@ async def generate(
     async def chat(
         self,
         message: Optional[str] = None,
-        query: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
         return_chatlog: Optional[bool] = False,
@@ -225,28 +224,8 @@ async def chat(
         k: Optional[float] = None,
         logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[AsyncChat, StreamingChat]:
-        if chat_history is not None:
-            should_warn = True
-            for entry in chat_history:
-                if "text" in entry:
-                    entry["message"] = entry["text"]
-
-                if "text" in entry and should_warn:
-                    logger.warning(
-                        "The 'text' parameter is deprecated and will be removed in a future version of this function. "
-                        + "Use 'message' instead.",
-                    )
-                    should_warn = False
-
-        if query is None and message is None:
-            raise CohereError("Either 'query' or 'message' must be provided.")
-
-        if query is not None:
-            logger.warning(
-                "The 'query' parameter is deprecated and will be removed in a future version of this function. "
-                + "Use 'message' instead.",
-            )
-            message = query
+        if message is None:
+            raise CohereError("'message' must be provided.")
 
         json_body = {
            "message": message,
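On the async side the only behavioral change beyond the removal is the stricter validation: omitting `message` no longer falls back to `query` and now raises immediately. A rough sketch, assuming a valid key and that `cohere.CohereError` and `AsyncClient.close()` are exported as in other 4.x releases:

```python
import asyncio

import cohere


async def main() -> None:
    co = cohere.AsyncClient("YOUR_API_KEY")  # placeholder key
    try:
        res = await co.chat(message="Hey! How are you doing today?")
        print(res.text)

        # Calling chat() with no message now fails fast:
        try:
            await co.chat()
        except cohere.CohereError as e:
            print(e)  # 'message' must be provided.
    finally:
        await co.close()


asyncio.run(main())
```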
" - + "Use 'message' instead.", - ) - message = query + if message is None: + raise CohereError("'message' must be provided.") json_body = { "message": message, diff --git a/cohere/responses/chat.py b/cohere/responses/chat.py index cfbc984c7..24127b694 100644 --- a/cohere/responses/chat.py +++ b/cohere/responses/chat.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Generator, List, NamedTuple, Optional +from typing import Any, Dict, Generator, List, Optional import requests @@ -25,10 +25,9 @@ def __init__( super().__init__(**kwargs) self.response_id = response_id self.generation_id = generation_id - self.query = message # to be deprecated self.message = message self.text = text - self.conversation_id = conversation_id + self.conversation_id = conversation_id # optional self.prompt = prompt # optional self.chatlog = chatlog # optional self.preamble = preamble # optional @@ -47,7 +46,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat": text=response.get("text"), prompt=response.get("prompt"), # optional chatlog=response.get("chatlog"), # optional - preamble=response.get("preamble"), # option + preamble=response.get("preamble"), # optional client=client, token_count=response.get("token_count"), meta=response.get("meta"), @@ -76,7 +75,38 @@ async def respond(self, response: str, max_tokens: int = None) -> "AsyncChat": ) -StreamingText = NamedTuple("StreamingText", [("index", Optional[int]), ("text", str), ("is_finished", bool)]) +class StreamResponse(CohereObject): + def __init__( + self, + is_finished: bool, + index: Optional[int], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.is_finished = is_finished + self.index = index + + +class StreamStart(StreamResponse): + def __init__( + self, + generation_id: str, + conversation_id: Optional[str], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.generation_id = generation_id + self.conversation_id = conversation_id + + +class StreamTextGeneration(StreamResponse): + def __init__( + self, + text: str, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.text = text class StreamingChat(CohereObject): @@ -85,34 +115,47 @@ def __init__(self, response): self.texts = [] self.response_id = None self.conversation_id = None + self.generation_id = None self.preamble = None self.prompt = None self.chatlog = None self.finish_reason = None + self.token_count = None + self.meta = None def _make_response_item(self, index, line) -> Any: streaming_item = json.loads(line) - is_finished = streaming_item.get("is_finished") - text = streaming_item.get("text") - - if not is_finished: - return StreamingText(text=text, is_finished=is_finished, index=index) - - response = streaming_item.get("response") - - if response is None: + event_type = streaming_item.get("event_type") + + if event_type == "stream-start": + self.conversation_id = streaming_item.get("conversation_id") + self.generation_id = streaming_item.get("generation_id") + return StreamStart( + conversation_id=self.conversation_id, generation_id=self.generation_id, is_finished=False, index=index + ) + elif event_type == "text-generation": + text = streaming_item.get("text") + return StreamTextGeneration(text=text, is_finished=False, index=index) + elif event_type == "stream-end": + response = streaming_item.get("response") + self.finish_reason = streaming_item.get("finish_reason") + + if response is None: + return None + + self.response_id = response.get("response_id") + self.conversation_id = response.get("conversation_id") + self.texts = 
[response.get("text")] + self.generation_id = response.get("generation_id") + self.preamble = response.get("preamble") + self.prompt = response.get("prompt") + self.chatlog = response.get("chatlog") + self.token_count = response.get("token_count") + self.meta = response.get("meta") return None - - self.response_id = response.get("response_id") - self.conversation_id = response.get("conversation_id") - self.preamble = response.get("preamble") - self.prompt = response.get("prompt") - self.chatlog = response.get("chatlog") - self.finish_reason = streaming_item.get("finish_reason") - self.texts = [response.get("text")] return None - def __iter__(self) -> Generator[StreamingText, None, None]: + def __iter__(self) -> Generator[StreamResponse, None, None]: if not isinstance(self.response, requests.Response): raise ValueError("For AsyncClient, use `async for` to iterate through the `StreamingChat`") @@ -121,7 +164,7 @@ def __iter__(self) -> Generator[StreamingText, None, None]: if item is not None: yield item - async def __aiter__(self) -> Generator[StreamingText, None, None]: + async def __aiter__(self) -> Generator[StreamResponse, None, None]: index = 0 async for line in self.response.content: item = self._make_response_item(index, line) diff --git a/pyproject.toml b/pyproject.toml index af29e6b35..c8920bd42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "4.20.2" +version = "4.21" description = "" authors = ["Cohere"] readme = "README.md" diff --git a/tests/async/test_async_chat.py b/tests/async/test_async_chat.py index b24a27d57..7f98544fd 100644 --- a/tests/async/test_async_chat.py +++ b/tests/async/test_async_chat.py @@ -1,5 +1,7 @@ import pytest +import cohere + @pytest.mark.asyncio async def test_async_multi_replies(async_client): @@ -36,12 +38,17 @@ async def test_async_chat_stream(async_client): expected_index = 0 expected_text = "" async for token in res: - if token.text: + if isinstance(token, cohere.responses.chat.StreamStart): + assert token.generation_id is not None + assert not token.is_finished + elif isinstance(token, cohere.responses.chat.StreamTextGeneration): assert isinstance(token.text, str) assert len(token.text) > 0 - assert token.index == expected_index - expected_text += token.text + assert not token.is_finished + + assert isinstance(token.index, int) + assert token.index == expected_index expected_index += 1 assert res.texts == [expected_text] diff --git a/tests/sync/test_chat.py b/tests/sync/test_chat.py index 30843ba2f..b3e9e0fdf 100644 --- a/tests/sync/test_chat.py +++ b/tests/sync/test_chat.py @@ -91,14 +91,16 @@ def test_stream(self): expected_index = 0 expected_text = "" for token in prediction: - if token.text: + if isinstance(token, cohere.responses.chat.StreamStart): + self.assertIsNotNone(token.generation_id) + self.assertFalse(token.is_finished) + elif isinstance(token, cohere.responses.chat.StreamTextGeneration): self.assertIsInstance(token.text, str) self.assertGreater(len(token.text), 0) - - self.assertIsInstance(token.index, int) - self.assertEqual(token.index, expected_index) - expected_text += token.text + self.assertFalse(token.is_finished) + self.assertIsInstance(token.index, int) + self.assertEqual(token.index, expected_index) expected_index += 1 self.assertEqual(prediction.texts, [expected_text])