From 352d9f582188cd6a8b4ced0ea477eebeb2d0cbd8 Mon Sep 17 00:00:00 2001 From: Alekhya Date: Wed, 20 Sep 2023 10:09:56 -0700 Subject: [PATCH 1/3] add input_type to embed --- cohere/client.py | 3 +++ cohere/client_async.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/cohere/client.py b/cohere/client.py index db5261a56..510d8810c 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -395,6 +395,7 @@ def embed( truncate: Optional[str] = None, compress: Optional[bool] = False, compression_codebook: Optional[str] = "default", + input_type: Optional[str] = None, ) -> Embeddings: """Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings. @@ -404,6 +405,7 @@ def embed( truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length. compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255]. compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default". + input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed. """ responses = { "embeddings": [], @@ -420,6 +422,7 @@ def embed( "truncate": truncate, "compress": compress, "compression_codebook": compression_codebook, + "input_type": input_type, } ) diff --git a/cohere/client_async.py b/cohere/client_async.py index 56e9f2b43..9d952af75 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -273,6 +273,7 @@ async def embed( truncate: Optional[str] = None, compress: Optional[bool] = False, compression_codebook: Optional[str] = "default", + input_type: Optional[str] = None, ) -> Embeddings: """Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings. @@ -282,6 +283,7 @@ async def embed( truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length. compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255]. compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default". + input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed. """ json_bodys = [ dict( @@ -290,6 +292,7 @@ async def embed( truncate=truncate, compress=compress, compression_codebook=compression_codebook, + input_type=input_type, ) for i in range(0, len(texts), cohere.COHERE_EMBED_BATCH_SIZE) ] From 352b83eb39debaa498cb59a06be8246eb74a0556 Mon Sep 17 00:00:00 2001 From: Alekhya Date: Wed, 20 Sep 2023 23:05:12 -0700 Subject: [PATCH 2/3] added tests for input_type --- tests/sync/test_embed.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/sync/test_embed.py b/tests/sync/test_embed.py index da85c75c5..4dd625e43 100644 --- a/tests/sync/test_embed.py +++ b/tests/sync/test_embed.py @@ -2,9 +2,11 @@ import string import unittest +import pytest from utils import get_api_key import cohere +from cohere.error import CohereError API_KEY = get_api_key() co = cohere.Client(API_KEY) @@ -102,3 +104,24 @@ def test_success_multiple_batches_in_order(self): for predictionExpected, predictionActual in zip(predictionsExpected, list(predictionsActual)): for elementExpected, elementAcutal in zip(predictionExpected, predictionActual): self.assertAlmostEqual(elementExpected, elementAcutal, places=1) + + def test_fail_with_new_model_no_input_type(self): + text_batch = random_texts(cohere.COHERE_EMBED_BATCH_SIZE) + with pytest.raises(CohereError): + co.embed(model="embed-english-v3.0", texts=text_batch) + + def test_fail_with_new_model_invalid_input_type(self): + text_batch = random_texts(cohere.COHERE_EMBED_BATCH_SIZE) + input_type = "invalid" + with pytest.raises(CohereError): + co.embed(model="embed-english-v3.0", texts=text_batch, input_type=input_type) + + def test_success_with_new_model_and_input_type(self): + text = ["cohere"] + input_types = ["classification", "search_document", "search_query", "clustering"] + + for input_type in input_types: + prediction = co.embed(model="embed-english-v3.0", texts=text, input_type=input_type) + embed = prediction.embeddings[0] + self.assertIsInstance(embed, list) + self.assertEqual(len(embed), 1024) From 613d3b63b21412bd9c93185837da2afa9fc52a41 Mon Sep 17 00:00:00 2001 From: Alekhya Date: Wed, 20 Sep 2023 23:09:01 -0700 Subject: [PATCH 3/3] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d90abbd31..15255e0fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 4.28 + - [#310] (https://github.com/cohere-ai/cohere-python/pull/310) + - Embed: add input_type parameter for new embed models + ## 4.27 - [#308] (https://github.com/cohere-ai/cohere-python/pull/308) - Datasets: add validation_warnings