diff --git a/CHANGELOG.md b/CHANGELOG.md index b953da7b5..9809996fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 4.23 +- [#294] (https://github.com/cohere-ai/cohere-python/pull/294) + - Allow passing of ParseInfo for datasets + ## 4.22 - [#292] (https://github.com/cohere-ai/cohere-python/pull/292) - Add search query only parameter diff --git a/cohere/client.py b/cohere/client.py index 1aabbc476..a23e2fcdb 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -42,7 +42,7 @@ HyperParametersInput, ModelMetric, ) -from cohere.responses.dataset import BaseDataset, Dataset +from cohere.responses.dataset import BaseDataset, Dataset, ParseInfo from cohere.responses.detectlang import DetectLanguageResponse, Language from cohere.responses.embed_job import EmbedJob from cohere.responses.embeddings import Embeddings @@ -753,6 +753,7 @@ def create_dataset( dataset_type: str, keep_fields: Union[str, List[str]] = None, optional_fields: Union[str, List[str]] = None, + parse_info: Optional[ParseInfo] = None, ) -> Dataset: """Returns a Dataset given input data @@ -762,7 +763,7 @@ def create_dataset( dataset_type (str): The type of dataset you want to upload keep_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are required optional_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are optional - + parse_info (ParseInfo): (optional) information on how to parse the raw data Returns: Dataset: Dataset object. 
""" @@ -773,6 +774,9 @@ def create_dataset( "keep_fields": keep_fields, "optional_fields": optional_fields, } + if parse_info: + params.update(parse_info.get_params()) + logger.warning("uploading file, starting validation...") create_response = self._request(cohere.DATASET_URL, files=files, params=params) logger.warning(f"{create_response['id']} was uploaded") diff --git a/cohere/client_async.py b/cohere/client_async.py index 3e4adb464..e3cb2d9d3 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -54,7 +54,7 @@ HyperParametersInput, ModelMetric, ) -from cohere.responses.dataset import AsyncDataset, BaseDataset +from cohere.responses.dataset import AsyncDataset, BaseDataset, ParseInfo from cohere.responses.embed_job import AsyncEmbedJob from cohere.utils import async_wait_for_job, is_api_key_valid, np_json_dumps @@ -491,6 +491,7 @@ async def create_dataset( dataset_type: str, keep_fields: Union[str, List[str]] = None, optional_fields: Union[str, List[str]] = None, + parse_info: Optional[ParseInfo] = None, ) -> AsyncDataset: """Returns a Dataset given input data @@ -500,7 +501,7 @@ async def create_dataset( dataset_type (str): The type of dataset you want to upload keep_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are required optional_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are optional - + parse_info (ParseInfo): (optional) information on how to parse the raw data Returns: AsyncDataset: Dataset object. 
""" @@ -513,6 +514,8 @@ async def create_dataset( params["keep_fields"] = keep_fields if optional_fields: params["optional_fields"] = optional_fields + if parse_info: + params.update(parse_info.get_params()) logger.warning("uploading file, starting validation...") create_response = await self._request(cohere.DATASET_URL, files=files, params=params) diff --git a/cohere/responses/dataset.py b/cohere/responses/dataset.py index 240a5c3cd..5d8600370 100644 --- a/cohere/responses/dataset.py +++ b/cohere/responses/dataset.py @@ -1,5 +1,6 @@ import csv import json +from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, Dict, List, Optional @@ -141,3 +142,14 @@ async def wait( updated_job = await self._wait_fn(dataset_id=self.id, timeout=timeout, interval=interval) self._update_self(updated_job) return updated_job + + +@dataclass +class ParseInfo: + separator: Optional[str] + + def get_params(self) -> Dict[str, str]: + params = {} + if self.separator: + params["text_separator"] = self.separator + return params diff --git a/pyproject.toml b/pyproject.toml index 0e12ba497..b9466e9a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "4.22" +version = "4.23" description = "" authors = ["Cohere"] readme = "README.md"