diff --git a/CHANGELOG.md b/CHANGELOG.md
index d90abbd31..15255e0fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Changelog
 
+## 4.28
+- [#310](https://github.com/cohere-ai/cohere-python/pull/310)
+  - Embed: add input_type parameter for new embed models
+
 ## 4.27
 - [#308](https://github.com/cohere-ai/cohere-python/pull/308)
   - Datasets: add validation_warnings
diff --git a/cohere/client.py b/cohere/client.py
index db5261a56..510d8810c 100644
--- a/cohere/client.py
+++ b/cohere/client.py
@@ -395,6 +395,7 @@ def embed(
         truncate: Optional[str] = None,
         compress: Optional[bool] = False,
         compression_codebook: Optional[str] = "default",
+        input_type: Optional[str] = None,
     ) -> Embeddings:
         """Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
 
@@ -404,6 +405,7 @@ def embed(
             truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
             compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
             compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
+            input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
         """
         responses = {
             "embeddings": [],
@@ -420,6 +422,7 @@ def embed(
                 "truncate": truncate,
                 "compress": compress,
                 "compression_codebook": compression_codebook,
+                "input_type": input_type,
             }
         )
diff --git a/cohere/client_async.py b/cohere/client_async.py
index 56e9f2b43..9d952af75 100644
--- a/cohere/client_async.py
+++ b/cohere/client_async.py
@@ -273,6 +273,7 @@ async def embed(
         truncate: Optional[str] = None,
         compress: Optional[bool] = False,
         compression_codebook: Optional[str] = "default",
+        input_type: Optional[str] = None,
     ) -> Embeddings:
         """Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
 
@@ -282,6 +283,7 @@ async def embed(
             truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
             compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
             compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
+            input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
         """
         json_bodys = [
             dict(
@@ -290,6 +292,7 @@ async def embed(
                 truncate=truncate,
                 compress=compress,
                 compression_codebook=compression_codebook,
+                input_type=input_type,
             )
             for i in range(0, len(texts), cohere.COHERE_EMBED_BATCH_SIZE)
         ]
diff --git a/tests/sync/test_embed.py b/tests/sync/test_embed.py
index da85c75c5..4dd625e43 100644
--- a/tests/sync/test_embed.py
+++ b/tests/sync/test_embed.py
@@ -2,9 +2,11 @@
 import string
 import unittest
 
+import pytest
 from utils import get_api_key
 
 import cohere
+from cohere.error import CohereError
 
 API_KEY = get_api_key()
 co = cohere.Client(API_KEY)
@@ -102,3 +104,24 @@ def test_success_multiple_batches_in_order(self):
         for predictionExpected, predictionActual in zip(predictionsExpected, list(predictionsActual)):
             for elementExpected, elementAcutal in zip(predictionExpected, predictionActual):
                 self.assertAlmostEqual(elementExpected, elementAcutal, places=1)
+
+    def test_fail_with_new_model_no_input_type(self):
+        text_batch = random_texts(cohere.COHERE_EMBED_BATCH_SIZE)
+        with pytest.raises(CohereError):
+            co.embed(model="embed-english-v3.0", texts=text_batch)
+
+    def test_fail_with_new_model_invalid_input_type(self):
+        text_batch = random_texts(cohere.COHERE_EMBED_BATCH_SIZE)
+        input_type = "invalid"
+        with pytest.raises(CohereError):
+            co.embed(model="embed-english-v3.0", texts=text_batch, input_type=input_type)
+
+    def test_success_with_new_model_and_input_type(self):
+        text = ["cohere"]
+        input_types = ["classification", "search_document", "search_query", "clustering"]
+
+        for input_type in input_types:
+            prediction = co.embed(model="embed-english-v3.0", texts=text, input_type=input_type)
+            embed = prediction.embeddings[0]
+            self.assertIsInstance(embed, list)
+            self.assertEqual(len(embed), 1024)