From 15fffc74f57032cf363a5471182b70fb729bdcde Mon Sep 17 00:00:00 2001 From: alex-matton Date: Fri, 13 Oct 2023 17:11:06 -0400 Subject: [PATCH 1/6] add support for multilabel --- cohere/client.py | 10 +++++++++- cohere/client_async.py | 10 +++++++++- cohere/responses/classify.py | 29 ++++++++++++++++++++++------- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/cohere/client.py b/cohere/client.py index e389d5e18..f4eb9a7c0 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -489,7 +489,15 @@ def classify( for label, prediction in res["labels"].items(): labelObj[label] = LabelPrediction(prediction["confidence"]) classifications.append( - Classification(res["input"], res["prediction"], res["confidence"], labelObj, id=res["id"]) + Classification( + input=res["input"], + predictions=res.get("predictions", None), + confidences=res.get("confidences", None), + prediction=res.get("prediction", None), + confidence=res.get("confidence", None), + labels=labelObj, + id=res["id"], + ) ) return Classifications(classifications, response.get("meta")) diff --git a/cohere/client_async.py b/cohere/client_async.py index 25c895422..5dab2edc2 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -347,7 +347,15 @@ async def classify( for label, prediction in res["labels"].items(): labelObj[label] = LabelPrediction(prediction["confidence"]) classifications.append( - Classification(res["input"], res["prediction"], res["confidence"], labelObj, id=res["id"]) + Classification( + input=res["input"], + predictions=res.get("predictions", None), + confidences=res.get("confidences", None), + prediction=res.get("prediction", None), + confidence=res.get("confidence", None), + labels=labelObj, + id=res["id"], + ) ) return Classifications(classifications, response["meta"]) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index 5c13ec2b9..25c2f9713 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -8,19 +8,34 @@ class Classification(CohereObject): def __init__( - self, input: str, prediction: str, confidence: float, labels: Dict[str, LabelPrediction], *args, **kwargs + self, + input: str, + predictions: Optional[List[str]], + confidences: Optional[List[float]], + prediction: Optional[str], + confidence: Optional[float], + labels: Dict[str, LabelPrediction], + *args, + **kwargs, ) -> None: super().__init__(*args, **kwargs) self.input = input - self.prediction = prediction - self.confidence = confidence + self.prediction = prediction # to be deprecated + self.confidence = confidence # to be deprecated + self.predictions = predictions + self.confidences = confidences self.labels = labels def __repr__(self) -> str: - prediction = self.prediction - confidence = self.confidence - labels = self.labels - return f'Classification' + if self.prediction is not None: + assert self.confidence is not None + return ( + f'Classification' + ) + else: + assert self.predictions is not None + assert self.confidences is not None + return f'Classification' class Classifications(CohereObject): From 8cd63cccaee3138784e9af09a318a6a86dc5f6a0 Mon Sep 17 00:00:00 2001 From: alex-matton Date: Mon, 16 Oct 2023 14:39:26 -0400 Subject: [PATCH 2/6] address comments --- CHANGELOG.md | 6 ++ cohere/client.py | 1 + cohere/client_async.py | 1 + cohere/responses/classify.py | 34 ++++++++--- pyproject.toml | 2 +- tests/test_classify_format.py | 104 ++++++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 tests/test_classify_format.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e9075c290..d08ded8da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 4.31 + - [#324] (https://github.com/cohere-ai/cohere-python/pull/324) + - Classify: + - Deprecate `prediction` and `confidence` attribute + - Add new `predictions` and `confidences` attribute for single and multi label classification + ## 4.30 - [#313] (https://github.com/cohere-ai/cohere-python/pull/313) - change chatlog (string) to chat_history (array of messages) in /chat diff --git a/cohere/client.py b/cohere/client.py index f4eb9a7c0..266b250ed 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -496,6 +496,7 @@ def classify( prediction=res.get("prediction", None), confidence=res.get("confidence", None), labels=labelObj, + classification_type=res.get("classification_type", "single-label"), id=res["id"], ) ) diff --git a/cohere/client_async.py b/cohere/client_async.py index 5dab2edc2..bbc541c71 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -354,6 +354,7 @@ async def classify( prediction=res.get("prediction", None), confidence=res.get("confidence", None), labels=labelObj, + classification_type=res.get("classification_type", "single-label"), id=res["id"], ) ) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index 25c2f9713..4cb84aa1c 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -15,28 +15,44 @@ def __init__( prediction: Optional[str], confidence: Optional[float], labels: Dict[str, LabelPrediction], + classification_type: str, *args, **kwargs, ) -> None: super().__init__(*args, **kwargs) self.input = input - self.prediction = prediction # to be deprecated - self.confidence = confidence # to be deprecated + self._prediction = prediction # to be deprecated + self._confidence = confidence # to be deprecated self.predictions = predictions self.confidences = confidences self.labels = labels + self.classification_type = classification_type + + if self._prediction is None or self._confidence is None: + if self._prediction is not None or self._confidence is not None: + raise ValueError("Cannot have one of prediction and confidence be None and not the other one") + if self.predictions is None or self.confidences is None: + raise ValueError("Cannot have predictions or confidences be None if prediction is None") def __repr__(self) -> str: - if self.prediction is not None: - assert self.confidence is not None - return ( - f'Classification' - ) + if self._prediction is not None: + return f'Classification' else: - assert self.predictions is not None - assert self.confidences is not None return f'Classification' + @property + def prediction(self): + print("`prediction` is deprecated and will be removed soon. Please use `predictions` instead.") + return self._prediction + + @property + def confidence(self): + print("`confidence` is deprecated and will be removed soon. Please use `confidences` instead.") + return self._confidence + + def is_multilabel(self) -> bool: + return self.classification_type == "multi-label" + class Classifications(CohereObject): def __init__(self, classifications: List[Classification], meta: Optional[Dict[str, Any]] = None) -> None: diff --git a/pyproject.toml b/pyproject.toml index 9ee9d4917..dfbec6d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "4.30" +version = "4.31" description = "" authors = ["Cohere"] readme = "README.md" diff --git a/tests/test_classify_format.py b/tests/test_classify_format.py new file mode 100644 index 000000000..576030201 --- /dev/null +++ b/tests/test_classify_format.py @@ -0,0 +1,104 @@ +from cohere import Client + + +def test_classifcation_old_single_label_format(mocker): + response = { + "id": "8a2c7187-6c01-41c0-a241-c064ad9618a5", + "classifications": [ + { + "classification_type": "single-label", + "confidence": 0.24627389, + "confidences": [0.24627389], + "id": "d0dfe4ce-525d-4530-ab26-ded93a101116", + "input": "I don't like this movie", + "labels": { + "negative": {"confidence": 0.24627389}, + "neutral": {"confidence": 0.18561405}, + "positive": {"confidence": 0.1925146}, + "very negative": {"confidence": 0.20908539}, + "very positive": {"confidence": 0.16651207}, + }, + "prediction": "negative", + "predictions": ["negative"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + mocker.patch("cohere.Client._request", return_value=response) + co = Client("test_token") + result = co.classify(["i don't like this movie"], model="sentence classifier single label old") + # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported + assert result[0].predictions == ["negative"] + assert result[0].confidences == [0.24627389] + assert result[0].prediction == "negative" + assert result[0].confidence == 0.24627389 + assert not result[0].is_multilabel() + + +def test_classify_new_single_label_format(mocker): + response = { + "id": "e994e80f-08b1-402f-8653-ced25a946f3a", + "classifications": [ + { + "classification_type": "single-label", + "confidence": 0.8908454, + "confidences": [0.8908454], + "id": "b9823024-3ad1-47d5-aed9-2bc4cb7775c8", + "input": "i love this movie!", + "labels": { + "negative": {"confidence": 7.224075e-05}, + "neutral": {"confidence": 0.0011411251}, + "positive": {"confidence": 0.10786094}, + "very negative": {"confidence": 8.027619e-05}, + "very positive": {"confidence": 0.8908454}, + }, + "prediction": "very positive", + "predictions": ["very positive"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + mocker.patch("cohere.Client._request", return_value=response) + co = Client("test_token") + result = co.classify(["i love this movie!"], model="sentence classifier single label new") + # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported + assert result[0].predictions == ["very positive"] + assert result[0].confidences == [0.8908454] + assert result[0].prediction == "very positive" + assert result[0].confidence == 0.8908454 + assert not result[0].is_multilabel() + + +def test_classify_multilabel_format(mocker): + response = { + "id": "cee2e2c2-83be-4c99-ad46-288448000b3f", + "classifications": [ + { + "classification_type": "multi-label", + "confidences": [0.6740505], + "id": "ff5b50c5-3f07-4993-9345-d47d71736164", + "input": "i love this movie!", + "labels": { + "0": {"confidence": 0.005260852}, + "1": {"confidence": 0.0029810327}, + "2": {"confidence": 0.000119598575}, + "3": {"confidence": 5.507606e-06}, + "4": {"confidence": 0.00055277866}, + "5": {"confidence": 0.00054847926}, + "6": {"confidence": 0.6740505}, + "7": {"confidence": 0.017242778}, + "8": {"confidence": 0.00026323833}, + "9": {"confidence": 0.00012533751}, + }, + "predictions": ["6"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + mocker.patch("cohere.Client._request", return_value=response) + co = Client("test_token") + result = co.classify(["i love this movie!"], model="sentence classifier multi label new") + # prediction/confidence do not make sense for multi-label classification + assert result[0].predictions == ["6"] + assert result[0].confidences == [0.6740505] + assert result[0].is_multilabel() From 3ebe2ea22069ee17af189fd892347474510832ac Mon Sep 17 00:00:00 2001 From: alex-matton Date: Mon, 16 Oct 2023 14:46:16 -0400 Subject: [PATCH 3/6] replace mocker with monkeypatch --- tests/test_classify_format.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_classify_format.py b/tests/test_classify_format.py index 576030201..6c131e83e 100644 --- a/tests/test_classify_format.py +++ b/tests/test_classify_format.py @@ -1,7 +1,7 @@ from cohere import Client -def test_classifcation_old_single_label_format(mocker): +def test_classifcation_old_single_label_format(monkeypatch): response = { "id": "8a2c7187-6c01-41c0-a241-c064ad9618a5", "classifications": [ @@ -24,7 +24,7 @@ def test_classifcation_old_single_label_format(mocker): ], "meta": {"api_version": {"version": "1"}}, } - mocker.patch("cohere.Client._request", return_value=response) + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) co = Client("test_token") result = co.classify(["i don't like this movie"], model="sentence classifier single label old") # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported @@ -35,7 +35,7 @@ def test_classifcation_old_single_label_format(mocker): assert not result[0].is_multilabel() -def test_classify_new_single_label_format(mocker): +def test_classify_new_single_label_format(monkeypatch): response = { "id": "e994e80f-08b1-402f-8653-ced25a946f3a", "classifications": [ @@ -58,7 +58,7 @@ def test_classify_new_single_label_format(mocker): ], "meta": {"api_version": {"version": "1"}}, } - mocker.patch("cohere.Client._request", return_value=response) + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) co = Client("test_token") result = co.classify(["i love this movie!"], model="sentence classifier single label new") # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported @@ -69,7 +69,7 @@ def test_classify_new_single_label_format(mocker): assert not result[0].is_multilabel() -def test_classify_multilabel_format(mocker): +def test_classify_multilabel_format(monkeypatch): response = { "id": "cee2e2c2-83be-4c99-ad46-288448000b3f", "classifications": [ @@ -95,7 +95,7 @@ def test_classify_multilabel_format(mocker): ], "meta": {"api_version": {"version": "1"}}, } - mocker.patch("cohere.Client._request", return_value=response) + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) co = Client("test_token") result = co.classify(["i love this movie!"], model="sentence classifier multi label new") # prediction/confidence do not make sense for multi-label classification From d793cf335c378f2c19234529d2003c1e945eb85e Mon Sep 17 00:00:00 2001 From: alex-matton Date: Tue, 17 Oct 2023 09:48:10 -0400 Subject: [PATCH 4/6] print -> log --- cohere/responses/classify.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index 4cb84aa1c..e9ac8b8fe 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, NamedTuple, Optional +from cohere.logging import logger from cohere.responses.base import CohereObject LabelPrediction = NamedTuple("LabelPrediction", [("confidence", float)]) @@ -42,12 +43,12 @@ def __repr__(self) -> str: @property def prediction(self): - print("`prediction` is deprecated and will be removed soon. Please use `predictions` instead.") + logger.warning("`prediction` is deprecated and will be removed soon. Please use `predictions` instead.") return self._prediction @property def confidence(self): - print("`confidence` is deprecated and will be removed soon. Please use `confidences` instead.") + logger.warning("`confidence` is deprecated and will be removed soon. Please use `confidences` instead.") return self._confidence def is_multilabel(self) -> bool: From da18bbc93dec624ef3a2e24e9b3311aa00e39d2b Mon Sep 17 00:00:00 2001 From: alex-matton Date: Tue, 17 Oct 2023 17:07:30 -0400 Subject: [PATCH 5/6] address comment --- cohere/responses/classify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index e9ac8b8fe..839917145 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -31,9 +31,9 @@ def __init__( if self._prediction is None or self._confidence is None: if self._prediction is not None or self._confidence is not None: - raise ValueError("Cannot have one of prediction and confidence be None and not the other one") + raise ValueError("Cannot have one of `prediction` and `confidence` be None and not the other one") if self.predictions is None or self.confidences is None: - raise ValueError("Cannot have predictions or confidences be None if prediction is None") + raise ValueError("Cannot have `predictions` or `confidences` be None if `prediction` is None") def __repr__(self) -> str: if self._prediction is not None: From f6decb0222a55369e95f74cf1096ed3f767e73f6 Mon Sep 17 00:00:00 2001 From: alex-matton Date: Thu, 19 Oct 2023 09:39:48 -0400 Subject: [PATCH 6/6] remove ambiguity in comments --- cohere/responses/classify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index 839917145..5554ec06e 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -22,8 +22,8 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) self.input = input - self._prediction = prediction # to be deprecated - self._confidence = confidence # to be deprecated + self._prediction = prediction # to be removed + self._confidence = confidence # to be removed self.predictions = predictions self.confidences = confidences self.labels = labels