10 changes: 5 additions & 5 deletions cohere/client.py
@@ -229,7 +229,7 @@ def chat(
         message: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
-        return_chatlog: Optional[bool] = False,
+        return_chat_history: Optional[bool] = False,
         return_prompt: Optional[bool] = False,
         return_preamble: Optional[bool] = False,
         chat_history: Optional[List[Dict[str, str]]] = None,
@@ -265,7 +265,7 @@ def chat(
             logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply.
             max_tokens (int): (Optional) The max tokens generated for the next reply.

-            return_chatlog (bool): (Optional) Whether to return the chatlog.
+            return_chat_history (bool): (Optional) Whether to return the chat history.
             return_prompt (bool): (Optional) Whether to return the prompt.
             return_preamble (bool): (Optional) Whether to return the preamble.

@@ -302,9 +302,9 @@ def chat(
             >>> message="Hey! How are you doing today?",
             >>> conversation_id="1234",
             >>> model="command",
-            >>> return_chatlog=True)
+            >>> return_chat_history=True)
             >>> print(res.text)
-            >>> print(res.chatlog)
+            >>> print(res.chat_history)
             Streaming chat:
             >>> res = co.chat(
             >>> message="Hey! How are you doing today?",
@@ -360,7 +360,7 @@ def chat(
             "message": message,
             "conversation_id": conversation_id,
             "model": model,
-            "return_chatlog": return_chatlog,
+            "return_chat_history": return_chat_history,
             "return_prompt": return_prompt,
             "return_preamble": return_preamble,
             "chat_history": chat_history,
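For callers, the change is a mechanical keyword rename: return_chatlog becomes return_chat_history, and the response attribute moves from .chatlog to .chat_history. A minimal before/after sketch, assuming a placeholder API key (illustrative, not a real credential):

import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

# Before this change:
#   res = co.chat("Hey! How are you doing today?", return_chatlog=True)
#   print(res.chatlog)

# After this change: the kwarg and the response attribute are both renamed.
res = co.chat("Hey! How are you doing today?", return_chat_history=True)
print(res.text)
print(res.chat_history)  # list of {"role": ..., "message": ...} dicts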
4 changes: 2 additions & 2 deletions cohere/client_async.py
@@ -213,7 +213,7 @@ async def chat(
         message: Optional[str] = None,
         conversation_id: Optional[str] = "",
         model: Optional[str] = None,
-        return_chatlog: Optional[bool] = False,
+        return_chat_history: Optional[bool] = False,
         return_prompt: Optional[bool] = False,
         return_preamble: Optional[bool] = False,
         chat_history: Optional[List[Dict[str, str]]] = None,
@@ -238,7 +238,7 @@ async def chat(
             "message": message,
             "conversation_id": conversation_id,
             "model": model,
-            "return_chatlog": return_chatlog,
+            "return_chat_history": return_chat_history,
             "return_prompt": return_prompt,
             "return_preamble": return_preamble,
             "chat_history": chat_history,
14 changes: 7 additions & 7 deletions cohere/responses/chat.py
@@ -17,7 +17,7 @@ def __init__(
         conversation_id: str,
         meta: Optional[Dict[str, Any]] = None,
         prompt: Optional[str] = None,
-        chatlog: Optional[List[Dict[str, str]]] = None,
+        chat_history: Optional[List[Dict[str, Any]]] = None,
         preamble: Optional[str] = None,
         token_count: Optional[Dict[str, int]] = None,
         is_search_required: Optional[bool] = None,
@@ -35,7 +35,7 @@ def __init__(
         self.text = text
         self.conversation_id = conversation_id  # optional
         self.prompt = prompt  # optional
-        self.chatlog = chatlog  # optional
+        self.chat_history = chat_history  # optional
         self.preamble = preamble  # optional
         self.client = client
         self.token_count = token_count
@@ -56,7 +56,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
             conversation_id=response.get("conversation_id"),  # optional
             text=response.get("text"),
             prompt=response.get("prompt"),  # optional
-            chatlog=response.get("chatlog"),  # optional
+            chat_history=response.get("chat_history"),  # optional
             preamble=response.get("preamble"),  # optional
             client=client,
             token_count=response.get("token_count"),
@@ -72,7 +72,7 @@ def respond(self, response: str, max_tokens: int = None) -> "Chat":
         return self.client.chat(
             message=response,
             conversation_id=self.conversation_id,
-            return_chatlog=self.chatlog is not None,
+            return_chat_history=self.chat_history is not None,
             return_prompt=self.prompt is not None,
             return_preamble=self.preamble is not None,
             max_tokens=max_tokens,
@@ -84,7 +84,7 @@ async def respond(self, response: str, max_tokens: int = None) -> "AsyncChat":
         return await self.client.chat(
             message=response,
             conversation_id=self.conversation_id,
-            return_chatlog=self.chatlog is not None,
+            return_chat_history=self.chat_history is not None,
             return_prompt=self.prompt is not None,
             return_preamble=self.preamble is not None,
             max_tokens=max_tokens,
@@ -187,7 +187,7 @@ def __init__(self, response):
         self.generation_id = None
         self.preamble = None
         self.prompt = None
-        self.chatlog = None
+        self.chat_history = None
         self.finish_reason = None
         self.token_count = None
         self.meta = None
@@ -246,7 +246,7 @@ def _make_response_item(self, index, line) -> Any:
         self.generation_id = response.get("generation_id")
         self.preamble = response.get("preamble")
         self.prompt = response.get("prompt")
-        self.chatlog = response.get("chatlog")
+        self.chat_history = response.get("chat_history")
         self.token_count = response.get("token_count")
         self.meta = response.get("meta")
         self.is_search_required = response.get("is_search_required")  # optional
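Because respond() forwards return_chat_history based on whether the previous response carried a chat_history, multi-turn chains keep the history populated. A minimal sketch of that chaining, assuming a placeholder API key and a hypothetical conversation id:

import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

first = co.chat(
    "Yo what's up?",
    conversation_id="demo-conversation",  # hypothetical id
    return_chat_history=True,
    max_tokens=5,
)
# respond() re-issues chat() on the same client and passes
# return_chat_history=True because first.chat_history is not None.
reply = first.respond("oh that's cool", max_tokens=5)
assert reply.chat_history is not None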
6 changes: 3 additions & 3 deletions tests/async/test_async_chat.py
@@ -17,14 +17,14 @@ async def test_async_multi_replies(async_client):
     conversation_id = f"test_conv_{conftest.random_word()}"
     num_replies = 3
     prediction = await async_client.chat(
-        "Yo what's up?", return_chatlog=True, max_tokens=5, conversation_id=conversation_id
+        "Yo what's up?", return_chat_history=True, max_tokens=5, conversation_id=conversation_id
     )
-    assert prediction.chatlog is not None
+    assert prediction.chat_history is not None
     for _ in range(num_replies):
         prediction = await prediction.respond("oh that's cool", max_tokens=5)
         assert isinstance(prediction.text, str)
         assert isinstance(prediction.conversation_id, str)
-        assert prediction.chatlog is not None
+        assert prediction.chat_history is not None
     assert prediction.meta
     assert prediction.meta["api_version"]
     assert prediction.meta["api_version"]["version"]
20 changes: 11 additions & 9 deletions tests/sync/test_chat.py
@@ -40,17 +40,19 @@ def test_invalid_model(self):
         with self.assertRaises(cohere.CohereError):
             co.chat("Yo what up?", model="NOT_A_VALID_MODEL").text

-    def test_return_chatlog(self):
-        prediction = co.chat("Yo what up?", return_chatlog=True, max_tokens=5)
+    def test_return_chat_history(self):
+        prediction = co.chat("Yo what up?", return_chat_history=True, max_tokens=5)
         self.assertIsInstance(prediction.text, str)
-        self.assertIsNotNone(prediction.chatlog)
-        self.assertGreaterEqual(len(prediction.chatlog), len(prediction.text))
+        self.assertIsNotNone(prediction.chat_history)
+        self.assertIsInstance(prediction.chat_history, list)
+        self.assertEqual(len(prediction.chat_history), 2)
+        self.assertIsInstance(prediction.chat_history[0], dict)

-    def test_return_chatlog_false(self):
-        prediction = co.chat("Yo what up?", return_chatlog=False, max_tokens=5)
+    def test_return_chat_history_false(self):
+        prediction = co.chat("Yo what up?", return_chat_history=False, max_tokens=5)
         self.assertIsInstance(prediction.text, str)

-        assert prediction.chatlog is None
+        assert prediction.chat_history is None

     def test_return_prompt(self):
         prediction = co.chat("Yo what up?", return_prompt=True, max_tokens=5)
@@ -146,11 +148,11 @@ def test_chat_history(self):
                 {"role": "Chatbot", "message": "Hey! How can I help you?"},
             ],
             return_prompt=True,
-            return_chatlog=True,
+            return_chat_history=True,
             max_tokens=5,
         )
         self.assertIsInstance(prediction.text, str)
-        self.assertIsNotNone(prediction.chatlog)
+        self.assertIsNotNone(prediction.chat_history)
         self.assertIn("User: Hey!", prediction.prompt)
         self.assertIn("Chatbot: Hey! How can I help you?", prediction.prompt)

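As the test above shows, chat_history entries are dicts with "role" and "message" keys, and the same structure comes back when return_chat_history=True. A minimal sketch of seeding a conversation with prior turns, assuming a placeholder API key:

import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

res = co.chat(
    "Hey!",
    chat_history=[
        {"role": "User", "message": "Hey!"},
        {"role": "Chatbot", "message": "Hey! How can I help you?"},
    ],
    return_chat_history=True,
    max_tokens=5,
)
print(res.chat_history)  # prior turns plus the new exchange (assumed behavior)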