Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 4.23
- [#294] (https://github.com/cohere-ai/cohere-python/pull/294)
- Allow passing of ParseInfo for datasets

## 4.22
- [#292] (https://github.com/cohere-ai/cohere-python/pull/292)
- Add search query only parameter
Expand Down
8 changes: 6 additions & 2 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
HyperParametersInput,
ModelMetric,
)
from cohere.responses.dataset import BaseDataset, Dataset
from cohere.responses.dataset import BaseDataset, Dataset, ParseInfo
from cohere.responses.detectlang import DetectLanguageResponse, Language
from cohere.responses.embed_job import EmbedJob
from cohere.responses.embeddings import Embeddings
Expand Down Expand Up @@ -753,6 +753,7 @@ def create_dataset(
dataset_type: str,
keep_fields: Union[str, List[str]] = None,
optional_fields: Union[str, List[str]] = None,
parse_info: Optional[ParseInfo] = None,
) -> Dataset:
"""Returns a Dataset given input data

Expand All @@ -762,7 +763,7 @@ def create_dataset(
dataset_type (str): The type of dataset you want to upload
keep_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are required
optional_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are optional

parse_info: ParseInfo: (optional) information on how to parse the raw data
Returns:
Dataset: Dataset object.
"""
Expand All @@ -773,6 +774,9 @@ def create_dataset(
"keep_fields": keep_fields,
"optional_fields": optional_fields,
}
if parse_info:
params.update(parse_info.get_params())

logger.warning("uploading file, starting validation...")
create_response = self._request(cohere.DATASET_URL, files=files, params=params)
logger.warning(f"{create_response['id']} was uploaded")
Expand Down
7 changes: 5 additions & 2 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
HyperParametersInput,
ModelMetric,
)
from cohere.responses.dataset import AsyncDataset, BaseDataset
from cohere.responses.dataset import AsyncDataset, BaseDataset, ParseInfo
from cohere.responses.embed_job import AsyncEmbedJob
from cohere.utils import async_wait_for_job, is_api_key_valid, np_json_dumps

Expand Down Expand Up @@ -491,6 +491,7 @@ async def create_dataset(
dataset_type: str,
keep_fields: Union[str, List[str]] = None,
optional_fields: Union[str, List[str]] = None,
parse_info: Optional[ParseInfo] = None,
) -> AsyncDataset:
"""Returns a Dataset given input data

Expand All @@ -500,7 +501,7 @@ async def create_dataset(
dataset_type (str): The type of dataset you want to upload
keep_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are required
optional_fields (Union[str, List[str]]): (optional) A list of fields you want to keep in the dataset that are optional

parse_info: ParseInfo: (optional) information on how to parse the raw data
Returns:
AsyncDataset: Dataset object.
"""
Expand All @@ -513,6 +514,8 @@ async def create_dataset(
params["keep_fields"] = keep_fields
if optional_fields:
params["optional_fields"] = optional_fields
if parse_info:
params.update(parse_info.get_params())

logger.warning("uploading file, starting validation...")
create_response = await self._request(cohere.DATASET_URL, files=files, params=params)
Expand Down
12 changes: 12 additions & 0 deletions cohere/responses/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -141,3 +142,14 @@ async def wait(
updated_job = await self._wait_fn(dataset_id=self.id, timeout=timeout, interval=interval)
self._update_self(updated_job)
return updated_job


@dataclass
class ParseInfo:
separator: Optional[str]

def get_params(self) -> Dict[str, str]:
params = {}
if self.separator:
params["text_separator"] = self.separator
return params
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.22"
version = "4.23"
description = ""
authors = ["Cohere"]
readme = "README.md"
Expand Down