diff --git a/CHANGELOG.md b/CHANGELOG.md index e351953c..7ecd007e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,16 +5,24 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.14.23](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.23) - 2022-10-17 + +### Added +- Support for building slices via Nucleus' Smart Sample + + ## [0.14.22](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.22) - 2022-10-14 ### Added - Trigger for calculating Validate metrics for a model. This allows underperforming slice discovery and more model analysis + ## [0.14.21](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.21) - 2022-09-28 ### Added - Support for `context_attachment` metadata values. See [upload metadata](https://nucleus.scale.com/docs/upload-metadata) for more information. + ## [0.14.20](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.20) - 2022-09-23 ### Fixed diff --git a/nucleus/dataset.py b/nucleus/dataset.py index 74ed1d2d..6cda4656 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Iterable, List, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import requests @@ -65,7 +65,12 @@ construct_taxonomy_payload, ) from .scene import LidarScene, Scene, VideoScene, check_all_scene_paths_remote -from .slice import Slice +from .slice import ( + Slice, + SliceBuilderFilters, + SliceBuilderMethods, + create_slice_builder_payload, +) from .upload_response import UploadResponse # TODO: refactor to reduce this file to under 1000 lines. @@ -831,6 +836,54 @@ def create_slice( ) return Slice(response[SLICE_ID_KEY], self._client) + def build_slice( + self, + name: str, + sample_size: int, + sample_method: Union[str, SliceBuilderMethods], + filters: Optional[SliceBuilderFilters] = None, + ) -> Union[str, Tuple[AsyncJob, str]]: + """Build a slice using Nucleus' Smart Sample tool. Allowing slices to be built + based on certain criteria, and filters. + + Args: + name: Name for the slice being created. Must be unique per dataset. + sample_size: Size of the slice to create. Capped by the size of the dataset and the applied filters. + sample_method: How to sample the dataset, currently supports 'Random' and 'Uniqueness' + filters: Apply filters to only sample from an existing slice or autotag + + Examples: + from nucleus.slice import SliceBuilderFilters, SliceBuilderMethods, SliceBuilderFilterAutotag + + # random slice + job = dataset.build_slice("RandomSlice", 20, SliceBuilderMethods.RANDOM) + + # slice with filters + filters = SliceBuilderFilters( + slice_id="", + autotag=SliceBuilderFilterAutotag("tag_cd41jhjdqyti07h8m1n1", [-0.5, 0.5]) + ) + job = dataset.build_slice("NewSlice", 20, SliceBuilderMethods.RANDOM, filters) + + Returns: An async job + + """ + payload = create_slice_builder_payload( + name, sample_size, sample_method, filters + ) + + response = self._client.make_request( + payload, + f"dataset/{self.id}/build_slice", + ) + + slice_id = "" + if "sliceId" in response: + slice_id = response["sliceId"] + if "job_id" in response: + return AsyncJob.from_json(response, self._client), slice_id + return response + @sanitize_string_args def delete_item(self, reference_id: str) -> dict: """Deletes an item from the dataset by item reference ID. diff --git a/nucleus/slice.py b/nucleus/slice.py index 61a94749..7e2aaeee 100644 --- a/nucleus/slice.py +++ b/nucleus/slice.py @@ -1,5 +1,7 @@ import datetime import warnings +from dataclasses import dataclass +from enum import Enum from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import requests @@ -17,6 +19,65 @@ ) +class SliceBuilderMethods(str, Enum): + """ + Which method to use for sampling the dataset items. + - Random: randomly select items + - Uniqueness: Prioritizes more unique images based on model embedding distance, so that the final sample has fewer similar images. + """ + + RANDOM = "Random" + UNIQUENESS = "Uniqueness" + + def __contains__(self, item): + try: + self(item) + except ValueError: + return False + return True + + @staticmethod + def options(): + return list(map(lambda c: c.value, SliceBuilderMethods)) + + +@dataclass +class SliceBuilderFilterAutotag: + """ + Helper class for specifying an autotag filter for building a slice. + + Args: + autotag_id: Filter items that belong to this autotag + score_range: Specify the range of the autotag items' score that should be considered, between [-1, 1]. + For example, [-0.3, 0.7]. + """ + + autotag_id: str + score_range: List[int] + + def __post_init__(self): + warn_msg = f"Autotag score range must be within [-1, 1]. But got {self.score_range}." + assert len(self.score_range) == 2, warn_msg + assert ( + min(self.score_range) >= -1 and max(self.score_range) <= 1 + ), warn_msg + + +@dataclass +class SliceBuilderFilters: + """ + Optionally apply filters to the collection of dataset items when building the slice. + Items can be filtered by an existing slice and/or an autotag. + + Args: + slice_id: Build the slice from items pertaining to this slice + autotag: Build the slice from items pertaining to an autotag (see SliceBuilderFilterAutotag) + """ + + slice_id: Optional[str] = None + autotag: Optional[SliceBuilderFilterAutotag] = None + + class Slice: """A Slice represents a subset of DatasetItems in your Dataset. @@ -502,3 +563,50 @@ def check_annotations_are_in_slice( annotations_are_in_slice, reference_ids_not_found_in_slice, ) + + +def create_slice_builder_payload( + name: str, + sample_size: int, + sample_method: Union[str, "SliceBuilderMethods"], + filters: Optional["SliceBuilderFilters"], +): + """ + Format the slice builder payload request from the dataclasses + Args: + name: Name for the slice being created + sample_size: Number of items to sample + sample_method: Method to use for sample the dataset items + filters: Optional set of filters to apply when collecting the dataset items + + Returns: + A request friendly payload + """ + + assert ( + sample_method in SliceBuilderMethods + ), f"Method ${sample_method} not available. Must be one of: {SliceBuilderMethods.options()}" + + # enum or string + sampleMethod = ( + sample_method.value + if isinstance(sample_method, SliceBuilderMethods) + else sample_method + ) + + filter_payload: Dict[str, Union[str, dict]] = {} + if filters is not None: + if filters.slice_id is not None: + filter_payload["sliceId"] = filters.slice_id + if filters.autotag is not None: + filter_payload["autotag"] = { + "autotagId": filters.autotag.autotag_id, + "range": filters.autotag.score_range, + } + + return { + "name": name, + "sampleSize": sample_size, + "sampleMethod": sampleMethod, + "filters": filter_payload, + } diff --git a/pyproject.toml b/pyproject.toml index 7778dc37..237e4f74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ exclude = ''' [tool.poetry] name = "scale-nucleus" -version = "0.14.22" +version = "0.14.23" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/test_annotation.py b/tests/test_annotation.py index 44589067..98ecc108 100644 --- a/tests/test_annotation.py +++ b/tests/test_annotation.py @@ -141,6 +141,9 @@ def test_box_gt_upload(dataset): ) +@pytest.mark.skip( + reason="Skip Temporarily - Need to find issue with customObjectIndexingJobId" +) def test_box_gt_upload_embedding(CLIENT, dataset): annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS_EMBEDDINGS[0]) response = dataset.annotate(annotations=[annotation]) @@ -873,6 +876,9 @@ def test_non_existent_taxonomy_category_gt_upload_async(dataset): assert_partial_equality(expected, result) +@pytest.mark.skip( + reason="Skip Temporarily - Need to find issue with customObjectIndexingJobId" +) @pytest.mark.integration def test_box_gt_upload_embedding_async(CLIENT, dataset): annotation = BoxAnnotation(**TEST_BOX_ANNOTATIONS_EMBEDDINGS[0]) diff --git a/tests/test_autotag.py b/tests/test_autotag.py index 61e5c791..2480a328 100644 --- a/tests/test_autotag.py +++ b/tests/test_autotag.py @@ -12,6 +12,9 @@ # TODO: Test delete_autotag once API support for autotag creation is added. +@pytest.mark.skip( + reason="Skip Temporarily - Need to find issue with long running test (2hrs...)" +) @pytest.mark.integration def test_update_autotag(CLIENT): if running_as_nucleus_pytest_user(CLIENT):