diff --git a/autoblocks/_impl/datasets/client.py b/autoblocks/_impl/datasets/client.py index f7ac8141..1af25d08 100644 --- a/autoblocks/_impl/datasets/client.py +++ b/autoblocks/_impl/datasets/client.py @@ -7,6 +7,12 @@ from typing import List from typing import Optional +import httpx +from tenacity import retry +from tenacity import retry_if_exception_type +from tenacity import stop_after_attempt +from tenacity import wait_random_exponential + from autoblocks._impl.api.base_app_resource_client import BaseAppResourceClient from autoblocks._impl.api.exceptions import ValidationError from autoblocks._impl.api.utils.serialization import deserialize_model @@ -18,6 +24,7 @@ from autoblocks._impl.datasets.models.dataset import DatasetSchema from autoblocks._impl.datasets.models.dataset import SuccessResponse from autoblocks._impl.datasets.models.schema import create_schema_property +from autoblocks._impl.datasets.util import validate_unique_property_names from autoblocks._impl.util import cuid_generator log = logging.getLogger(__name__) @@ -29,6 +36,35 @@ class DatasetsClient(BaseAppResourceClient): def __init__(self, app_slug: str, api_key: str, timeout: timedelta = timedelta(seconds=60)) -> None: super().__init__(app_slug=app_slug, api_key=api_key, timeout=timeout) + def _process_schema(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Process schema items and validate them. + + Args: + schema: List of property dictionaries + + Returns: + List of processed schema properties + + Raises: + ValidationError: If any schema property is invalid + """ + processed_schema = [] + + for i, prop_dict in enumerate(schema): + prop_dict_copy = dict(prop_dict) + + if "id" not in prop_dict_copy: + prop_dict_copy["id"] = cuid_generator() + + try: + schema_prop = create_schema_property(prop_dict_copy) + processed_schema.append(serialize_model(schema_prop)) + except Exception as e: + raise ValidationError(f"Invalid schema property at index {i}: {str(e)}") + + return processed_schema + def list(self) -> List[Dataset]: """ List all datasets in the app. @@ -59,27 +95,14 @@ def create( Raises: ValidationError: If required parameters are missing or invalid """ + # Validate unique property names + validate_unique_property_names(schema) + # Construct the basic dataset data data: Dict[str, Any] = {"name": name} # Process schema items - processed_schema = [] - - for i, prop_dict in enumerate(schema): - prop_dict_copy = dict(prop_dict) - - if "id" not in prop_dict_copy: - prop_dict_copy["id"] = cuid_generator() - - try: - # 3. If valid, add to processed schema - schema_prop = create_schema_property(prop_dict_copy) - processed_schema.append(serialize_model(schema_prop)) - except Exception as e: - raise ValidationError(f"Invalid schema property at index {i}: {str(e)}") - - # Use the field alias to ensure it's sent as 'schema' to the API - data["schema"] = processed_schema + data["schema"] = self._process_schema(schema) # Make the API call path = self._build_app_path("datasets") @@ -110,6 +133,11 @@ def destroy( response = self._make_request("DELETE", path) return deserialize_model(SuccessResponse, response) + @retry( + retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout, httpx.WriteTimeout)), + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, min=4, max=10), + ) def get_items( self, *, @@ -140,6 +168,11 @@ def get_items( response = self._make_request("GET", path) return deserialize_model_list(DatasetItem, response) + @retry( + retry=retry_if_exception_type((httpx.ReadTimeout, httpx.ConnectTimeout, httpx.WriteTimeout)), + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=1, min=4, max=10), + ) def create_items( self, *, @@ -234,6 +267,46 @@ def get_items_by_schema_version( response = self._make_request("GET", path) return deserialize_model_list(DatasetItem, response) + def update_dataset( + self, + *, + external_id: str, + name: Optional[str] = None, + schema: Optional[List[Dict[str, Any]]] = None, + ) -> Dataset: + """ + Update a dataset. + + Args: + external_id: Dataset ID (required) + name: New dataset name (optional) + schema: New schema as list of property dictionaries (optional) + + Returns: + Updated dataset + + Raises: + ValidationError: If dataset ID is not provided or schema has duplicate property names + """ + if not external_id: + raise ValidationError("External ID is required") + + data: Dict[str, Any] = {} + + if name is not None: + data["name"] = name + + if schema is not None: + # Validate unique property names + validate_unique_property_names(schema) + + # Process schema items + data["schema"] = self._process_schema(schema) + + path = self._build_app_path("datasets", external_id) + response = self._make_request("PUT", path, data) + return deserialize_model(Dataset, response) + def update_item( self, *, diff --git a/autoblocks/_impl/datasets/util.py b/autoblocks/_impl/datasets/util.py index 70bc5483..f42714ff 100644 --- a/autoblocks/_impl/datasets/util.py +++ b/autoblocks/_impl/datasets/util.py @@ -1,5 +1,8 @@ +from typing import Any +from typing import Dict from typing import List +from autoblocks._impl.api.exceptions import ValidationError from autoblocks._impl.util import parse_autoblocks_overrides @@ -10,3 +13,26 @@ def get_selected_datasets() -> List[str]: """ overrides = parse_autoblocks_overrides() return overrides.test_selected_datasets + + +def validate_unique_property_names(schema: List[Dict[str, Any]]) -> None: + """ + Validate that all property names in schema are unique. + + Args: + schema: List of property dictionaries + + Raises: + ValidationError: If duplicate property names are found or if any property has no name + """ + # Extract property names and filter out None values + property_names = [] + for i, prop in enumerate(schema): + name = prop.get("name") + if name is None: + raise ValidationError(f"Property at index {i} has no name") + property_names.append(name) + + # Check for duplicates + if len(property_names) != len(set(property_names)): + raise ValidationError("Schema property names must be unique") diff --git a/autoblocks/datasets/utils.py b/autoblocks/datasets/utils.py index 754d1539..3ce786ae 100644 --- a/autoblocks/datasets/utils.py +++ b/autoblocks/datasets/utils.py @@ -1,3 +1,4 @@ from autoblocks._impl.datasets.util import get_selected_datasets +from autoblocks._impl.datasets.util import validate_unique_property_names -__all__ = ["get_selected_datasets"] +__all__ = ["get_selected_datasets", "validate_unique_property_names"] diff --git a/tests/autoblocks/test_app_client_datasets.py b/tests/autoblocks/test_app_client_datasets.py index 2fe60930..094df407 100644 --- a/tests/autoblocks/test_app_client_datasets.py +++ b/tests/autoblocks/test_app_client_datasets.py @@ -1,5 +1,8 @@ """Tests for AutoblocksAppClient dataset deserialization with defaultValue fields.""" +import pytest + +from autoblocks._impl.api.exceptions import ValidationError from autoblocks._impl.config.constants import API_ENDPOINT_V2 from autoblocks._impl.datasets.models.schema import SchemaPropertyType from autoblocks._impl.datasets.models.schema import SelectProperty @@ -198,3 +201,49 @@ def test_dataset_schema_property_factory_function(): assert string_prop.type == SchemaPropertyType.STRING assert string_prop.required is True assert string_prop.default_value is None + + +def test_update_dataset_duplicate_property_names(httpx_mock): + """Test that updating dataset with duplicate property names raises ValidationError.""" + client = AutoblocksAppClient(app_slug="test-app", api_key="mock-api-key") + + schema = [ + {"id": "prop-1", "name": "duplicate", "type": "String", "required": False}, + {"id": "prop-2", "name": "duplicate", "type": "Number", "required": False}, + ] + + with pytest.raises(ValidationError, match="Schema property names must be unique"): + client.datasets.update_dataset(external_id="test-dataset", schema=schema) + + +def test_update_dataset_success(httpx_mock): + """Test successful dataset update.""" + httpx_mock.add_response( + url=f"{API_ENDPOINT_V2}/apps/test-app/datasets/test-dataset", + method="PUT", + status_code=200, + json={ + "id": "test-dataset", + "externalId": "test-dataset", + "name": "Updated Dataset", + "createdAt": "2023-01-01T00:00:00Z", + "latestRevisionId": "rev-2", + "schemaVersion": 2, + "schema": [{"id": "prop-1", "name": "new_property", "type": "String", "required": False}], + }, + match_headers={"Authorization": "Bearer mock-api-key"}, + ) + + client = AutoblocksAppClient(app_slug="test-app", api_key="mock-api-key") + + schema = [{"id": "prop-1", "name": "new_property", "type": "String", "required": False}] + + result = client.datasets.update_dataset(external_id="test-dataset", name="Updated Dataset", schema=schema) + + assert result.id == "test-dataset" + assert result.external_id == "test-dataset" + assert result.name == "Updated Dataset" + assert result.schema_version == 2 + assert result.schema_properties is not None + assert len(result.schema_properties) == 1 + assert result.schema_properties[0].name == "new_property" diff --git a/tests/e2e/test_datasets.py b/tests/e2e/test_datasets.py index adfc5684..788841a3 100644 --- a/tests/e2e/test_datasets.py +++ b/tests/e2e/test_datasets.py @@ -212,9 +212,6 @@ def test_create_and_retrieve_items_with_conversation_data( class TestDatasetItemsOperations: """Test operations on dataset items.""" - # Class variable to store the item ID between test functions - test_item_id = None - @pytest.fixture(scope="class") def client(self) -> AutoblocksAppClient: return create_app_client() @@ -328,13 +325,24 @@ def test_retrieve_items_from_dataset(self, client: AutoblocksAppClient, test_dat empty_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id, splits=["nonexistent"]) assert len(empty_items) == 0 - # Store an item ID for update/delete tests in the class variable - TestDatasetItemsOperations.test_item_id = retrieved_items[0].id - def test_update_item_in_dataset(self, client: AutoblocksAppClient, test_dataset_id: str) -> None: """Test updating an item in the dataset.""" - # Make sure we have an item ID from the previous test - assert TestDatasetItemsOperations.test_item_id is not None + # Create an item first + items = [ + { + "Text Field": "Original text", + "Number Field": 50, + } + ] + + client.datasets.create_items(external_id=test_dataset_id, items=items, split_names=["train"]) + + # Get the created item ID + retrieved_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) + item_to_update = next( + (item for item in retrieved_items if item.data.get("Text Field") == "Original text"), None + ) + assert item_to_update is not None # Use the new keyword-only arguments update_data = { @@ -344,7 +352,7 @@ def test_update_item_in_dataset(self, client: AutoblocksAppClient, test_dataset_ update_result = client.datasets.update_item( external_id=test_dataset_id, - item_id=TestDatasetItemsOperations.test_item_id, + item_id=item_to_update.id, data=update_data, split_names=["validation"], ) @@ -352,10 +360,8 @@ def test_update_item_in_dataset(self, client: AutoblocksAppClient, test_dataset_ assert update_result.success is True # Verify the update - retrieved_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) - updated_item = next( - (item for item in retrieved_items if item.id == TestDatasetItemsOperations.test_item_id), None - ) + updated_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) + updated_item = next((item for item in updated_items if item.id == item_to_update.id), None) assert updated_item is not None assert updated_item.data["Text Field"] == "Updated sample text" @@ -364,8 +370,22 @@ def test_update_item_in_dataset(self, client: AutoblocksAppClient, test_dataset_ def test_delete_item_from_dataset(self, client: AutoblocksAppClient, test_dataset_id: str) -> None: """Test deleting an item from the dataset.""" - # Make sure we have an item ID from the previous test - assert TestDatasetItemsOperations.test_item_id is not None + # Create an item first + items = [ + { + "Text Field": "Item to delete", + "Number Field": 999, + } + ] + + client.datasets.create_items(external_id=test_dataset_id, items=items, split_names=["test"]) + + # Get the created item ID + retrieved_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) + item_to_delete = next( + (item for item in retrieved_items if item.data.get("Text Field") == "Item to delete"), None + ) + assert item_to_delete is not None # Get dataset state before deletion pre_delete_datasets = client.datasets.list() @@ -374,9 +394,7 @@ def test_delete_item_from_dataset(self, client: AutoblocksAppClient, test_datase pre_delete_revision_id = pre_delete_dataset.latest_revision_id # Use the new keyword-only arguments - delete_result = client.datasets.delete_item( - external_id=test_dataset_id, item_id=TestDatasetItemsOperations.test_item_id - ) + delete_result = client.datasets.delete_item(external_id=test_dataset_id, item_id=item_to_delete.id) assert delete_result.success is True @@ -387,11 +405,9 @@ def test_delete_item_from_dataset(self, client: AutoblocksAppClient, test_datase assert post_delete_dataset.latest_revision_id != pre_delete_revision_id # Verify the item is deleted - retrieved_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) + final_items: List[DatasetItem] = client.datasets.get_items(external_id=test_dataset_id) - deleted_item = next( - (item for item in retrieved_items if item.id == TestDatasetItemsOperations.test_item_id), None - ) + deleted_item = next((item for item in final_items if item.id == item_to_delete.id), None) assert deleted_item is None