Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# Changelog

## 4.21
- [#287] (https://github.com/cohere-ai/cohere-python/pull/287)
- Remove the deprecated chat "query" parameter and the deprecated "text" key inside the "chat_history" parameter
- Support event types for chat streaming

## 4.20.2
- [#284] (https://github.com/cohere-ai/cohere-python/pull/284)
- Rename dataset urls to download_urls
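
Taken together, the 4.21 entries above are a breaking change for callers still using the deprecated spellings. A minimal migration sketch, assuming a standard `cohere.Client` setup and a `user_name` key in each history entry (neither detail appears in this diff):

```python
import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key, not part of this diff

# Before 4.21 these deprecated spellings were silently remapped:
#   co.chat(query="Hey! How are you doing today?")
#   co.chat(message="...", chat_history=[{"user_name": "User", "text": "Hi"}])
# From 4.21 on, only the 'message' spelling is accepted in both places:
res = co.chat(
    message="Hey! How are you doing today?",
    chat_history=[
        {"user_name": "User", "message": "Hi there"},
        {"user_name": "Chatbot", "message": "Hello! How can I help?"},
    ],
)
print(res.text)
```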
47 changes: 15 additions & 32 deletions cohere/client.py
@@ -227,7 +227,6 @@ def generate(
def chat(
self,
message: Optional[str] = None,
query: Optional[str] = None,
conversation_id: Optional[str] = "",
model: Optional[str] = None,
return_chatlog: Optional[bool] = False,
@@ -246,30 +245,33 @@ def chat(
"""Returns a Chat object with the query reply.

Args:
query (str): Deprecated. Use message instead.
message (str): The message to send to the chatbot.
conversation_id (str): (Optional) The conversation id to continue the conversation.
model (str): (Optional) The model to use for generating the next reply.
return_chatlog (bool): (Optional) Whether to return the chatlog.
return_prompt (bool): (Optional) Whether to return the prompt.
return_preamble (bool): (Optional) Whether to return the preamble.
chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.
preamble_override (str): (Optional) A string to override the preamble.
user_name (str): (Optional) A string to override the username.
temperature (float): (Optional) The temperature to use for the next reply. The higher the temperature, the more random the reply.
max_tokens (int): (Optional) The max tokens generated for the next reply.

stream (bool): Return streaming tokens.
conversation_id (str): (Optional) To store and continue a conversation, create a conversation id and use it for every related request.

preamble_override (str): (Optional) A string to override the preamble.
chat_history (List[Dict[str, str]]): (Optional) A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.

model (str): (Optional) The model to use for generating the response.
temperature (float): (Optional) The temperature to use for the response. The higher the temperature, the more random the response.
p (float): (Optional) The nucleus sampling probability.
k (float): (Optional) The top-k sampling probability.
logit_bias (Dict[int, float]): (Optional) A dictionary of logit bias values to use for the next reply.
max_tokens (int): (Optional) The max tokens generated for the next reply.

return_chatlog (bool): (Optional) Whether to return the chatlog.
return_prompt (bool): (Optional) Whether to return the prompt.
return_preamble (bool): (Optional) Whether to return the preamble.

user_name (str): (Optional) A string to override the username.
Returns:
a Chat object if stream=False, or a StreamingChat object if stream=True

Examples:
A simple chat message:
>>> res = co.chat(message="Hey! How are you doing today?")
>>> print(res.text)
>>> print(res.conversation_id)
Continuing a session using a specific model:
>>> res = co.chat(
>>> message="Hey! How are you doing today?",
@@ -295,25 +297,6 @@ def chat(
>>> print(res.text)
>>> print(res.prompt)
"""
if chat_history is not None:
should_warn = True
for entry in chat_history:
if "text" in entry:
entry["message"] = entry["text"]

if "text" in entry and should_warn:
logger.warning(
"The 'text' parameter is deprecated and will be removed in a future version of this function. "
+ "Use 'message' instead.",
)
should_warn = False

if query is not None:
logger.warning(
"The chat_history 'text' key is deprecated and will be removed in a future version of this function. "
+ "Use 'message' instead.",
)
message = query

json_body = {
"message": message,
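
The reorganized docstring above groups the chat parameters by purpose (streaming, state, prompt construction, sampling, response flags). A short sketch of a stateful, non-streaming call using several of them; the model name, temperature value, and conversation-id format are illustrative assumptions rather than values taken from this diff:

```python
import cohere

co = cohere.Client("YOUR_API_KEY")  # placeholder key

conversation_id = "support-session-42"  # any caller-chosen string; state is kept per id

# First turn: the reply is generated and stored under the conversation id.
res = co.chat(
    message="Hey! How are you doing today?",
    conversation_id=conversation_id,
    model="command",      # assumed model name
    temperature=0.3,      # lower temperature -> less random reply
    max_tokens=200,
    return_prompt=True,
)
print(res.text)
print(res.prompt)  # available because return_prompt=True

# Later turn: reusing the same conversation id continues the stored conversation.
followup = co.chat(
    message="Summarize what we discussed so far.",
    conversation_id=conversation_id,
)
print(followup.text)
```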
25 changes: 2 additions & 23 deletions cohere/client_async.py
@@ -209,7 +209,6 @@ async def generate(
async def chat(
self,
message: Optional[str] = None,
query: Optional[str] = None,
conversation_id: Optional[str] = "",
model: Optional[str] = None,
return_chatlog: Optional[bool] = False,
@@ -225,28 +224,8 @@ async def chat(
k: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
) -> Union[AsyncChat, StreamingChat]:
if chat_history is not None:
should_warn = True
for entry in chat_history:
if "text" in entry:
entry["message"] = entry["text"]

if "text" in entry and should_warn:
logger.warning(
"The 'text' parameter is deprecated and will be removed in a future version of this function. "
+ "Use 'message' instead.",
)
should_warn = False

if query is None and message is None:
raise CohereError("Either 'query' or 'message' must be provided.")

if query is not None:
logger.warning(
"The 'query' parameter is deprecated and will be removed in a future version of this function. "
+ "Use 'message' instead.",
)
message = query
if message is None:
raise CohereError("'message' must be provided.")

json_body = {
"message": message,
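
On the async client the deprecated `query` fallback is removed entirely, so a missing `message` now raises immediately instead of being remapped. A sketch of the updated call, assuming the usual `cohere.AsyncClient` construction, a package-level `CohereError`, and an awaitable `close()` (none of which appear in this diff):

```python
import asyncio

import cohere


async def main() -> None:
    co = cohere.AsyncClient("YOUR_API_KEY")  # placeholder key
    try:
        # 'message' is now the only accepted spelling; calling co.chat() without it
        # raises CohereError("'message' must be provided.") before any request is sent.
        res = await co.chat(message="Hey! How are you doing today?")
        print(res.text)
    except cohere.CohereError as err:
        print(f"chat failed: {err}")
    finally:
        await co.close()  # assumed: the async client holds a session that should be closed


asyncio.run(main())
```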
91 changes: 67 additions & 24 deletions cohere/responses/chat.py
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Generator, List, NamedTuple, Optional
from typing import Any, Dict, Generator, List, Optional

import requests

@@ -25,10 +25,9 @@ def __init__(
super().__init__(**kwargs)
self.response_id = response_id
self.generation_id = generation_id
self.query = message # to be deprecated
self.message = message
self.text = text
self.conversation_id = conversation_id
self.conversation_id = conversation_id # optional
self.prompt = prompt # optional
self.chatlog = chatlog # optional
self.preamble = preamble # optional
@@ -47,7 +46,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
text=response.get("text"),
prompt=response.get("prompt"), # optional
chatlog=response.get("chatlog"), # optional
preamble=response.get("preamble"), # option
preamble=response.get("preamble"), # optional
client=client,
token_count=response.get("token_count"),
meta=response.get("meta"),
@@ -76,7 +75,38 @@ async def respond(self, response: str, max_tokens: int = None) -> "AsyncChat":
)


StreamingText = NamedTuple("StreamingText", [("index", Optional[int]), ("text", str), ("is_finished", bool)])
class StreamResponse(CohereObject):
def __init__(
self,
is_finished: bool,
index: Optional[int],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.is_finished = is_finished
self.index = index


class StreamStart(StreamResponse):
def __init__(
self,
generation_id: str,
conversation_id: Optional[str],
**kwargs,
) -> None:
super().__init__(**kwargs)
self.generation_id = generation_id
self.conversation_id = conversation_id


class StreamTextGeneration(StreamResponse):
def __init__(
self,
text: str,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.text = text


class StreamingChat(CohereObject):
@@ -85,34 +115,47 @@ def __init__(self, response):
self.texts = []
self.response_id = None
self.conversation_id = None
self.generation_id = None
self.preamble = None
self.prompt = None
self.chatlog = None
self.finish_reason = None
self.token_count = None
self.meta = None

def _make_response_item(self, index, line) -> Any:
streaming_item = json.loads(line)
is_finished = streaming_item.get("is_finished")
text = streaming_item.get("text")

if not is_finished:
return StreamingText(text=text, is_finished=is_finished, index=index)

response = streaming_item.get("response")

if response is None:
event_type = streaming_item.get("event_type")

if event_type == "stream-start":
self.conversation_id = streaming_item.get("conversation_id")
self.generation_id = streaming_item.get("generation_id")
return StreamStart(
conversation_id=self.conversation_id, generation_id=self.generation_id, is_finished=False, index=index
)
elif event_type == "text-generation":
text = streaming_item.get("text")
return StreamTextGeneration(text=text, is_finished=False, index=index)
elif event_type == "stream-end":
response = streaming_item.get("response")
self.finish_reason = streaming_item.get("finish_reason")

if response is None:
return None

self.response_id = response.get("response_id")
self.conversation_id = response.get("conversation_id")
self.texts = [response.get("text")]
self.generation_id = response.get("generation_id")
self.preamble = response.get("preamble")
self.prompt = response.get("prompt")
self.chatlog = response.get("chatlog")
self.token_count = response.get("token_count")
self.meta = response.get("meta")
return None

self.response_id = response.get("response_id")
self.conversation_id = response.get("conversation_id")
self.preamble = response.get("preamble")
self.prompt = response.get("prompt")
self.chatlog = response.get("chatlog")
self.finish_reason = streaming_item.get("finish_reason")
self.texts = [response.get("text")]
return None

def __iter__(self) -> Generator[StreamingText, None, None]:
def __iter__(self) -> Generator[StreamResponse, None, None]:
if not isinstance(self.response, requests.Response):
raise ValueError("For AsyncClient, use `async for` to iterate through the `StreamingChat`")

@@ -121,7 +164,7 @@ def __iter__(self) -> Generator[StreamingText, None, None]:
if item is not None:
yield item

async def __aiter__(self) -> Generator[StreamingText, None, None]:
async def __aiter__(self) -> Generator[StreamResponse, None, None]:
index = 0
async for line in self.response.content:
item = self._make_response_item(index, line)
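
With this change a streaming chat yields typed `StreamResponse` objects instead of `StreamingText` tuples, and the aggregate fields are filled in once the `stream-end` event arrives. A consumption sketch; the class names and attributes come from this diff, while the client setup is assumed:

```python
import cohere
from cohere.responses.chat import StreamStart, StreamTextGeneration

co = cohere.Client("YOUR_API_KEY")  # placeholder key

stream = co.chat(message="Hey! How are you doing today?", stream=True)

for event in stream:
    if isinstance(event, StreamStart):
        # Emitted for the "stream-start" event, before any text is generated.
        print("generation_id:", event.generation_id)
        print("conversation_id:", event.conversation_id)
    elif isinstance(event, StreamTextGeneration):
        # Emitted for each "text-generation" event; event.text is one chunk of the reply.
        print(event.text, end="", flush=True)

# The "stream-end" event is not yielded; instead it populates the stream object itself.
print()
print("finish_reason:", stream.finish_reason)
print("full reply:", stream.texts)
```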
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.20.2"
version = "4.21"
description = ""
authors = ["Cohere"]
readme = "README.md"
13 changes: 10 additions & 3 deletions tests/async/test_async_chat.py
@@ -1,5 +1,7 @@
import pytest

import cohere


@pytest.mark.asyncio
async def test_async_multi_replies(async_client):
@@ -36,12 +38,17 @@ async def test_async_chat_stream(async_client):
expected_index = 0
expected_text = ""
async for token in res:
if token.text:
if isinstance(token, cohere.responses.chat.StreamStart):
assert token.generation_id is not None
assert not token.is_finished
elif isinstance(token, cohere.responses.chat.StreamTextGeneration):
assert isinstance(token.text, str)
assert len(token.text) > 0
assert token.index == expected_index

expected_text += token.text
assert not token.is_finished

assert isinstance(token.index, int)
assert token.index == expected_index
expected_index += 1

assert res.texts == [expected_text]
12 changes: 7 additions & 5 deletions tests/sync/test_chat.py
@@ -91,14 +91,16 @@ def test_stream(self):
expected_index = 0
expected_text = ""
for token in prediction:
if token.text:
if isinstance(token, cohere.responses.chat.StreamStart):
self.assertIsNotNone(token.generation_id)
self.assertFalse(token.is_finished)
elif isinstance(token, cohere.responses.chat.StreamTextGeneration):
self.assertIsInstance(token.text, str)
self.assertGreater(len(token.text), 0)

self.assertIsInstance(token.index, int)
self.assertEqual(token.index, expected_index)

expected_text += token.text
self.assertFalse(token.is_finished)
self.assertIsInstance(token.index, int)
self.assertEqual(token.index, expected_index)
expected_index += 1

self.assertEqual(prediction.texts, [expected_text])