From 8aa0857d7f432ba8c9b3d54f5c3b5a58a80d2261 Mon Sep 17 00:00:00 2001 From: daniel Date: Wed, 27 Sep 2023 16:56:17 -0400 Subject: [PATCH 1/4] start adjusting tests and client for chatlog->chathistory change. 3 tests still failing. --- cohere/client.py | 10 ++-- cohere/client_async.py | 4 +- cohere/responses/chat.py | 101 +++++++++++++++++++-------------------- tests/sync/test_chat.py | 20 ++++---- 4 files changed, 68 insertions(+), 67 deletions(-) diff --git a/cohere/client.py b/cohere/client.py index 2f17bf357..e389d5e18 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -229,7 +229,7 @@ def chat( message: Optional[str] = None, conversation_id: Optional[str] = "", model: Optional[str] = None, - return_chatlog: Optional[bool] = False, + return_chat_history: Optional[bool] = False, return_prompt: Optional[bool] = False, return_preamble: Optional[bool] = False, chat_history: Optional[List[Dict[str, str]]] = None, @@ -265,7 +265,7 @@ def chat( logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply. max_tokens (int): (Optional) The max tokens generated for the next reply. - return_chatlog (bool): (Optional) Whether to return the chatlog. + return_chat_history (bool): (Optional) Whether to return the chat history. return_prompt (bool): (Optional) Whether to return the prompt. return_preamble (bool): (Optional) Whether to return the preamble. @@ -302,9 +302,9 @@ def chat( >>> message="Hey! How are you doing today?", >>> conversation_id="1234", >>> model="command", - >>> return_chatlog=True) + >>> return_chat_history=True) >>> print(res.text) - >>> print(res.chatlog) + >>> print(res.chat_history) Streaming chat: >>> res = co.chat( >>> message="Hey! How are you doing today?", @@ -360,7 +360,7 @@ def chat( "message": message, "conversation_id": conversation_id, "model": model, - "return_chatlog": return_chatlog, + "return_chat_history": return_chat_history, "return_prompt": return_prompt, "return_preamble": return_preamble, "chat_history": chat_history, diff --git a/cohere/client_async.py b/cohere/client_async.py index 9d952af75..25c895422 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -213,7 +213,7 @@ async def chat( message: Optional[str] = None, conversation_id: Optional[str] = "", model: Optional[str] = None, - return_chatlog: Optional[bool] = False, + return_chat_history: Optional[bool] = False, return_prompt: Optional[bool] = False, return_preamble: Optional[bool] = False, chat_history: Optional[List[Dict[str, str]]] = None, @@ -238,7 +238,7 @@ async def chat( "message": message, "conversation_id": conversation_id, "model": model, - "return_chatlog": return_chatlog, + "return_chat_history": return_chat_history, "return_prompt": return_prompt, "return_preamble": return_preamble, "chat_history": chat_history, diff --git a/cohere/responses/chat.py b/cohere/responses/chat.py index bc6ae5bcb..5dfe155c0 100644 --- a/cohere/responses/chat.py +++ b/cohere/responses/chat.py @@ -1,32 +1,31 @@ import json +import requests from enum import Enum from typing import Any, Dict, Generator, List, Optional, Union -import requests - from cohere.responses.base import CohereObject class Chat(CohereObject): def __init__( - self, - response_id: str, - generation_id: str, - message: str, - text: str, - conversation_id: str, - meta: Optional[Dict[str, Any]] = None, - prompt: Optional[str] = None, - chatlog: Optional[List[Dict[str, str]]] = None, - preamble: Optional[str] = None, - token_count: Optional[Dict[str, int]] = None, - is_search_required: Optional[bool] = None, - citations: Optional[List[Dict[str, Any]]] = None, - documents: Optional[List[Dict[str, Any]]] = None, - search_results: Optional[List[Dict[str, Any]]] = None, - search_queries: Optional[List[Dict[str, Any]]] = None, - client=None, - **kwargs, + self, + response_id: str, + generation_id: str, + message: str, + text: str, + conversation_id: str, + meta: Optional[Dict[str, Any]] = None, + prompt: Optional[str] = None, + chat_history: Optional[List[Dict[str, Any]]] = None, + preamble: Optional[str] = None, + token_count: Optional[Dict[str, int]] = None, + is_search_required: Optional[bool] = None, + citations: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, Any]]] = None, + search_results: Optional[List[Dict[str, Any]]] = None, + search_queries: Optional[List[Dict[str, Any]]] = None, + client=None, + **kwargs, ) -> None: super().__init__(**kwargs) self.response_id = response_id @@ -35,7 +34,7 @@ def __init__( self.text = text self.conversation_id = conversation_id # optional self.prompt = prompt # optional - self.chatlog = chatlog # optional + self.chat_history = chat_history # optional self.preamble = preamble # optional self.client = client self.token_count = token_count @@ -56,7 +55,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat": conversation_id=response.get("conversation_id"), # optional text=response.get("text"), prompt=response.get("prompt"), # optional - chatlog=response.get("chatlog"), # optional + chat_history=response.get("chat_history"), # optional preamble=response.get("preamble"), # optional client=client, token_count=response.get("token_count"), @@ -72,7 +71,7 @@ def respond(self, response: str, max_tokens: int = None) -> "Chat": return self.client.chat( message=response, conversation_id=self.conversation_id, - return_chatlog=self.chatlog is not None, + return_chat_history=self.chat_history is not None, return_prompt=self.prompt is not None, return_preamble=self.preamble is not None, max_tokens=max_tokens, @@ -84,7 +83,7 @@ async def respond(self, response: str, max_tokens: int = None) -> "AsyncChat": return await self.client.chat( message=response, conversation_id=self.conversation_id, - return_chatlog=self.chatlog is not None, + return_chat_history=self.chat_history is not None, return_prompt=self.prompt is not None, return_preamble=self.preamble is not None, max_tokens=max_tokens, @@ -102,11 +101,11 @@ class StreamEvent(str, Enum): class StreamResponse(CohereObject): def __init__( - self, - is_finished: bool, - event_type: Union[StreamEvent, str], - index: Optional[int], - **kwargs, + self, + is_finished: bool, + event_type: Union[StreamEvent, str], + index: Optional[int], + **kwargs, ) -> None: super().__init__(**kwargs) self.is_finished = is_finished @@ -116,10 +115,10 @@ def __init__( class StreamStart(StreamResponse): def __init__( - self, - generation_id: str, - conversation_id: Optional[str], - **kwargs, + self, + generation_id: str, + conversation_id: Optional[str], + **kwargs, ) -> None: super().__init__(**kwargs) self.generation_id = generation_id @@ -128,9 +127,9 @@ def __init__( class StreamTextGeneration(StreamResponse): def __init__( - self, - text: str, - **kwargs, + self, + text: str, + **kwargs, ) -> None: super().__init__(**kwargs) self.text = text @@ -138,9 +137,9 @@ def __init__( class StreamCitationGeneration(StreamResponse): def __init__( - self, - citations: Optional[List[Dict[str, Any]]], - **kwargs, + self, + citations: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.citations = citations @@ -148,9 +147,9 @@ def __init__( class StreamQueryGeneration(StreamResponse): def __init__( - self, - search_queries: Optional[List[Dict[str, Any]]], - **kwargs, + self, + search_queries: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.search_queries = search_queries @@ -158,10 +157,10 @@ def __init__( class StreamSearchResults(StreamResponse): def __init__( - self, - search_results: Optional[List[Dict[str, Any]]], - documents: Optional[List[Dict[str, Any]]], - **kwargs, + self, + search_results: Optional[List[Dict[str, Any]]], + documents: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.search_results = search_results @@ -170,9 +169,9 @@ def __init__( class StreamEnd(StreamResponse): def __init__( - self, - finish_reason: str, - **kwargs, + self, + finish_reason: str, + **kwargs, ) -> None: super().__init__(**kwargs) self.finish_reason = finish_reason @@ -187,7 +186,7 @@ def __init__(self, response): self.generation_id = None self.preamble = None self.prompt = None - self.chatlog = None + self.chat_history = None self.finish_reason = None self.token_count = None self.meta = None @@ -246,7 +245,7 @@ def _make_response_item(self, index, line) -> Any: self.generation_id = response.get("generation_id") self.preamble = response.get("preamble") self.prompt = response.get("prompt") - self.chatlog = response.get("chatlog") + self.chat_history = response.get("chat_history") self.token_count = response.get("token_count") self.meta = response.get("meta") self.is_search_required = response.get("is_search_required") # optional diff --git a/tests/sync/test_chat.py b/tests/sync/test_chat.py index 4818dc8e2..5dbd0f15b 100644 --- a/tests/sync/test_chat.py +++ b/tests/sync/test_chat.py @@ -40,17 +40,19 @@ def test_invalid_model(self): with self.assertRaises(cohere.CohereError): co.chat("Yo what up?", model="NOT_A_VALID_MODEL").text - def test_return_chatlog(self): - prediction = co.chat("Yo what up?", return_chatlog=True, max_tokens=5) + def test_return_chat_history(self): + prediction = co.chat("Yo what up?", return_chat_history=True, max_tokens=5) self.assertIsInstance(prediction.text, str) - self.assertIsNotNone(prediction.chatlog) - self.assertGreaterEqual(len(prediction.chatlog), len(prediction.text)) + self.assertIsNotNone(prediction.chat_history) + self.assertIsInstance(prediction.chat_history, list) + self.assertEqual(len(prediction.chat_history), 3) + self.assertIsInstance(prediction.chat_history[0], dict) - def test_return_chatlog_false(self): - prediction = co.chat("Yo what up?", return_chatlog=False, max_tokens=5) + def test_return_chat_history_false(self): + prediction = co.chat("Yo what up?", return_chat_history=False, max_tokens=5) self.assertIsInstance(prediction.text, str) - assert prediction.chatlog is None + assert prediction.chat_history is None def test_return_prompt(self): prediction = co.chat("Yo what up?", return_prompt=True, max_tokens=5) @@ -146,11 +148,11 @@ def test_chat_history(self): {"role": "Chatbot", "message": "Hey! How can I help you?"}, ], return_prompt=True, - return_chatlog=True, + return_chat_history=True, max_tokens=5, ) self.assertIsInstance(prediction.text, str) - self.assertIsNotNone(prediction.chatlog) + self.assertIsNotNone(prediction.chat_history) self.assertIn("User: Hey!", prediction.prompt) self.assertIn("Chatbot: Hey! How can I help you?", prediction.prompt) From 491bf5b403928ad96d8feb728edd47084123ae4d Mon Sep 17 00:00:00 2001 From: Angelique Ulep Date: Wed, 27 Sep 2023 22:03:18 -0400 Subject: [PATCH 2/4] Update chatlog to chat_history --- tests/async/test_async_chat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/async/test_async_chat.py b/tests/async/test_async_chat.py index a37147482..2efd6d50c 100644 --- a/tests/async/test_async_chat.py +++ b/tests/async/test_async_chat.py @@ -17,14 +17,14 @@ async def test_async_multi_replies(async_client): conversation_id = f"test_conv_{conftest.random_word()}" num_replies = 3 prediction = await async_client.chat( - "Yo what's up?", return_chatlog=True, max_tokens=5, conversation_id=conversation_id + "Yo what's up?", return_chat_history=True, max_tokens=5, conversation_id=conversation_id ) - assert prediction.chatlog is not None + assert prediction.chat_history is not None for _ in range(num_replies): prediction = await prediction.respond("oh that's cool", max_tokens=5) assert isinstance(prediction.text, str) assert isinstance(prediction.conversation_id, str) - assert prediction.chatlog is not None + assert prediction.chat_history is not None assert prediction.meta assert prediction.meta["api_version"] assert prediction.meta["api_version"]["version"] From de7b4ea67637eec199714bc319fad64b7a641292 Mon Sep 17 00:00:00 2001 From: daniel Date: Thu, 28 Sep 2023 09:35:44 -0400 Subject: [PATCH 3/4] precommit --- cohere/responses/chat.py | 89 ++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/cohere/responses/chat.py b/cohere/responses/chat.py index 5dfe155c0..1359fbd4e 100644 --- a/cohere/responses/chat.py +++ b/cohere/responses/chat.py @@ -1,31 +1,32 @@ import json -import requests from enum import Enum from typing import Any, Dict, Generator, List, Optional, Union +import requests + from cohere.responses.base import CohereObject class Chat(CohereObject): def __init__( - self, - response_id: str, - generation_id: str, - message: str, - text: str, - conversation_id: str, - meta: Optional[Dict[str, Any]] = None, - prompt: Optional[str] = None, - chat_history: Optional[List[Dict[str, Any]]] = None, - preamble: Optional[str] = None, - token_count: Optional[Dict[str, int]] = None, - is_search_required: Optional[bool] = None, - citations: Optional[List[Dict[str, Any]]] = None, - documents: Optional[List[Dict[str, Any]]] = None, - search_results: Optional[List[Dict[str, Any]]] = None, - search_queries: Optional[List[Dict[str, Any]]] = None, - client=None, - **kwargs, + self, + response_id: str, + generation_id: str, + message: str, + text: str, + conversation_id: str, + meta: Optional[Dict[str, Any]] = None, + prompt: Optional[str] = None, + chat_history: Optional[List[Dict[str, Any]]] = None, + preamble: Optional[str] = None, + token_count: Optional[Dict[str, int]] = None, + is_search_required: Optional[bool] = None, + citations: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, Any]]] = None, + search_results: Optional[List[Dict[str, Any]]] = None, + search_queries: Optional[List[Dict[str, Any]]] = None, + client=None, + **kwargs, ) -> None: super().__init__(**kwargs) self.response_id = response_id @@ -101,11 +102,11 @@ class StreamEvent(str, Enum): class StreamResponse(CohereObject): def __init__( - self, - is_finished: bool, - event_type: Union[StreamEvent, str], - index: Optional[int], - **kwargs, + self, + is_finished: bool, + event_type: Union[StreamEvent, str], + index: Optional[int], + **kwargs, ) -> None: super().__init__(**kwargs) self.is_finished = is_finished @@ -115,10 +116,10 @@ def __init__( class StreamStart(StreamResponse): def __init__( - self, - generation_id: str, - conversation_id: Optional[str], - **kwargs, + self, + generation_id: str, + conversation_id: Optional[str], + **kwargs, ) -> None: super().__init__(**kwargs) self.generation_id = generation_id @@ -127,9 +128,9 @@ def __init__( class StreamTextGeneration(StreamResponse): def __init__( - self, - text: str, - **kwargs, + self, + text: str, + **kwargs, ) -> None: super().__init__(**kwargs) self.text = text @@ -137,9 +138,9 @@ def __init__( class StreamCitationGeneration(StreamResponse): def __init__( - self, - citations: Optional[List[Dict[str, Any]]], - **kwargs, + self, + citations: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.citations = citations @@ -147,9 +148,9 @@ def __init__( class StreamQueryGeneration(StreamResponse): def __init__( - self, - search_queries: Optional[List[Dict[str, Any]]], - **kwargs, + self, + search_queries: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.search_queries = search_queries @@ -157,10 +158,10 @@ def __init__( class StreamSearchResults(StreamResponse): def __init__( - self, - search_results: Optional[List[Dict[str, Any]]], - documents: Optional[List[Dict[str, Any]]], - **kwargs, + self, + search_results: Optional[List[Dict[str, Any]]], + documents: Optional[List[Dict[str, Any]]], + **kwargs, ) -> None: super().__init__(**kwargs) self.search_results = search_results @@ -169,9 +170,9 @@ def __init__( class StreamEnd(StreamResponse): def __init__( - self, - finish_reason: str, - **kwargs, + self, + finish_reason: str, + **kwargs, ) -> None: super().__init__(**kwargs) self.finish_reason = finish_reason From 0b4bf227d583af73bbb5c0f19cc517a423dc3a04 Mon Sep 17 00:00:00 2001 From: daniel Date: Thu, 28 Sep 2023 10:10:20 -0400 Subject: [PATCH 4/4] fix test --- tests/sync/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sync/test_chat.py b/tests/sync/test_chat.py index 5dbd0f15b..d5331e76f 100644 --- a/tests/sync/test_chat.py +++ b/tests/sync/test_chat.py @@ -45,7 +45,7 @@ def test_return_chat_history(self): self.assertIsInstance(prediction.text, str) self.assertIsNotNone(prediction.chat_history) self.assertIsInstance(prediction.chat_history, list) - self.assertEqual(len(prediction.chat_history), 3) + self.assertEqual(len(prediction.chat_history), 2) self.assertIsInstance(prediction.chat_history[0], dict) def test_return_chat_history_false(self):