diff --git a/CHANGELOG.md b/CHANGELOG.md index e9075c290..d08ded8da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 4.31 + - [#324] (https://github.com/cohere-ai/cohere-python/pull/324) + - Classify: + - Deprecate `prediction` and `confidence` attributes + - Add new `predictions` and `confidences` attributes for single- and multi-label classification + ## 4.30 - [#313] (https://github.com/cohere-ai/cohere-python/pull/313) - change chatlog (string) to chat_history (array of messages) in /chat diff --git a/cohere/client.py b/cohere/client.py index e389d5e18..266b250ed 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -489,7 +489,16 @@ def classify( for label, prediction in res["labels"].items(): labelObj[label] = LabelPrediction(prediction["confidence"]) classifications.append( - Classification(res["input"], res["prediction"], res["confidence"], labelObj, id=res["id"]) + Classification( + input=res["input"], + predictions=res.get("predictions", None), + confidences=res.get("confidences", None), + prediction=res.get("prediction", None), + confidence=res.get("confidence", None), + labels=labelObj, + classification_type=res.get("classification_type", "single-label"), + id=res["id"], + ) ) return Classifications(classifications, response.get("meta")) diff --git a/cohere/client_async.py b/cohere/client_async.py index 25c895422..bbc541c71 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -347,7 +347,16 @@ async def classify( for label, prediction in res["labels"].items(): labelObj[label] = LabelPrediction(prediction["confidence"]) classifications.append( - Classification(res["input"], res["prediction"], res["confidence"], labelObj, id=res["id"]) + Classification( + input=res["input"], + predictions=res.get("predictions", None), + confidences=res.get("confidences", None), + prediction=res.get("prediction", None), + confidence=res.get("confidence", None), + labels=labelObj, + 
classification_type=res.get("classification_type", "single-label"), + id=res["id"], + ) ) return Classifications(classifications, response["meta"]) diff --git a/cohere/responses/classify.py b/cohere/responses/classify.py index 5c13ec2b9..5554ec06e 100644 --- a/cohere/responses/classify.py +++ b/cohere/responses/classify.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, NamedTuple, Optional +from cohere.logging import logger from cohere.responses.base import CohereObject LabelPrediction = NamedTuple("LabelPrediction", [("confidence", float)]) @@ -8,19 +9,50 @@ class Classification(CohereObject): def __init__( - self, input: str, prediction: str, confidence: float, labels: Dict[str, LabelPrediction], *args, **kwargs + self, + input: str, + predictions: Optional[List[str]], + confidences: Optional[List[float]], + prediction: Optional[str], + confidence: Optional[float], + labels: Dict[str, LabelPrediction], + classification_type: str, + *args, + **kwargs, ) -> None: super().__init__(*args, **kwargs) self.input = input - self.prediction = prediction - self.confidence = confidence + self._prediction = prediction # to be removed + self._confidence = confidence # to be removed + self.predictions = predictions + self.confidences = confidences self.labels = labels + self.classification_type = classification_type + + if self._prediction is None or self._confidence is None: + if self._prediction is not None or self._confidence is not None: + raise ValueError("Cannot have one of `prediction` and `confidence` be None and not the other one") + if self.predictions is None or self.confidences is None: + raise ValueError("Cannot have `predictions` or `confidences` be None if `prediction` is None") def __repr__(self) -> str: - prediction = self.prediction - confidence = self.confidence - labels = self.labels - return f'Classification' + if self._prediction is not None: + return f'Classification' + else: + return f'Classification' + + @property + def prediction(self): + 
logger.warning("`prediction` is deprecated and will be removed soon. Please use `predictions` instead.") + return self._prediction + + @property + def confidence(self): + logger.warning("`confidence` is deprecated and will be removed soon. Please use `confidences` instead.") + return self._confidence + + def is_multilabel(self) -> bool: + return self.classification_type == "multi-label" class Classifications(CohereObject): diff --git a/pyproject.toml b/pyproject.toml index 9ee9d4917..dfbec6d7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "4.30" +version = "4.31" description = "" authors = ["Cohere"] readme = "README.md" diff --git a/tests/test_classify_format.py b/tests/test_classify_format.py new file mode 100644 index 000000000..6c131e83e --- /dev/null +++ b/tests/test_classify_format.py @@ -0,0 +1,104 @@ +from cohere import Client + + +def test_classifcation_old_single_label_format(monkeypatch): + response = { + "id": "8a2c7187-6c01-41c0-a241-c064ad9618a5", + "classifications": [ + { + "classification_type": "single-label", + "confidence": 0.24627389, + "confidences": [0.24627389], + "id": "d0dfe4ce-525d-4530-ab26-ded93a101116", + "input": "I don't like this movie", + "labels": { + "negative": {"confidence": 0.24627389}, + "neutral": {"confidence": 0.18561405}, + "positive": {"confidence": 0.1925146}, + "very negative": {"confidence": 0.20908539}, + "very positive": {"confidence": 0.16651207}, + }, + "prediction": "negative", + "predictions": ["negative"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) + co = Client("test_token") + result = co.classify(["i don't like this movie"], model="sentence classifier single label old") + # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported + assert result[0].predictions == ["negative"] + assert result[0].confidences == 
[0.24627389] + assert result[0].prediction == "negative" + assert result[0].confidence == 0.24627389 + assert not result[0].is_multilabel() + + +def test_classify_new_single_label_format(monkeypatch): + response = { + "id": "e994e80f-08b1-402f-8653-ced25a946f3a", + "classifications": [ + { + "classification_type": "single-label", + "confidence": 0.8908454, + "confidences": [0.8908454], + "id": "b9823024-3ad1-47d5-aed9-2bc4cb7775c8", + "input": "i love this movie!", + "labels": { + "negative": {"confidence": 7.224075e-05}, + "neutral": {"confidence": 0.0011411251}, + "positive": {"confidence": 0.10786094}, + "very negative": {"confidence": 8.027619e-05}, + "very positive": {"confidence": 0.8908454}, + }, + "prediction": "very positive", + "predictions": ["very positive"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) + co = Client("test_token") + result = co.classify(["i love this movie!"], model="sentence classifier single label new") + # Both deprecated fields (prediction/confidence) and new fields (predictions/confidences) are supported + assert result[0].predictions == ["very positive"] + assert result[0].confidences == [0.8908454] + assert result[0].prediction == "very positive" + assert result[0].confidence == 0.8908454 + assert not result[0].is_multilabel() + + +def test_classify_multilabel_format(monkeypatch): + response = { + "id": "cee2e2c2-83be-4c99-ad46-288448000b3f", + "classifications": [ + { + "classification_type": "multi-label", + "confidences": [0.6740505], + "id": "ff5b50c5-3f07-4993-9345-d47d71736164", + "input": "i love this movie!", + "labels": { + "0": {"confidence": 0.005260852}, + "1": {"confidence": 0.0029810327}, + "2": {"confidence": 0.000119598575}, + "3": {"confidence": 5.507606e-06}, + "4": {"confidence": 0.00055277866}, + "5": {"confidence": 0.00054847926}, + "6": {"confidence": 0.6740505}, + "7": {"confidence": 0.017242778}, + "8": 
{"confidence": 0.00026323833}, + "9": {"confidence": 0.00012533751}, + }, + "predictions": ["6"], + }, + ], + "meta": {"api_version": {"version": "1"}}, + } + monkeypatch.setattr("cohere.Client._request", lambda *args, **kwargs: response) + co = Client("test_token") + result = co.classify(["i love this movie!"], model="sentence classifier multi label new") + # prediction/confidence do not make sense for multi-label classification + assert result[0].predictions == ["6"] + assert result[0].confidences == [0.6740505] + assert result[0].is_multilabel()