Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.14.30](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.30) - 2022-11-29

### Added
- Support for uploading track-level metrics to external evaluation functions using track_ref_ids

## [0.14.29](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.14.29) - 2022-11-22

### Added
Expand Down
12 changes: 6 additions & 6 deletions nucleus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
import tqdm.notebook as tqdm_notebook

self.tqdm_bar = tqdm_notebook.tqdm
self._connection = Connection(self.api_key, self.endpoint)
self.connection = Connection(self.api_key, self.endpoint)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tricky design decision here. Track needs access to at least a Connection object to be able to call the API to upload track metadata, preferably NucleusClient access for consistency with other classes like Model. However, a Validate instance is created from NucleusClient, and so attempting to import Track from ScenarioTest, from where it is created in get_items, will cause a circular import. The lesser of two evils was to update Track.from_json to take a Connection instead of a NucleusClient.

The tradeoff here was exposing connection as a public property of NucleusClient so as to access this property when creating Tracks in cases when only a NucleusClient is available. This should be safe, since a Connection consists of two public properties––api_key and endpoint.

Lmk if there are any issues with this approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is fine.

I think passing around a connection is a newer approach than passing the client. I actually don't see any reason for a class having to have full access to the client. We should probably refactor everything to pass around the connection object itself.

I don't see any harm in doing things this way so ... let's go ahead 🙂

self.validate = Validate(self.api_key, self.endpoint)

def __repr__(self):
Expand Down Expand Up @@ -1014,16 +1014,16 @@ def create_object_index(
)

def delete(self, route: str):
return self._connection.delete(route)
return self.connection.delete(route)

def get(self, route: str):
return self._connection.get(route)
return self.connection.get(route)

def post(self, payload: dict, route: str):
return self._connection.post(payload, route)
return self.connection.post(payload, route)

def put(self, payload: dict, route: str):
return self._connection.put(payload, route)
return self.connection.put(payload, route)

# TODO: Fix return type, can be a list as well. Brings on a lot of mypy errors ...
def make_request(
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def make_request(
"Received defined payload with GET request! Will ignore payload"
)
payload = None
return self._connection.make_request(payload, route, requests_command, return_raw_response) # type: ignore
return self.connection.make_request(payload, route, requests_command, return_raw_response) # type: ignore

def _set_api_key(self, api_key):
"""Fetch API key from environment variable NUCLEUS_API_KEY if not set"""
Expand Down
2 changes: 1 addition & 1 deletion nucleus/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,7 +1863,7 @@ def tracks(self) -> List[Track]:
tracks_list = [
Track.from_json(
payload=track,
client=self._client,
connection=self._client.connection,
)
for track in response[TRACKS_KEY]
]
Expand Down
10 changes: 8 additions & 2 deletions nucleus/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def from_json(
frames = [Frame.from_json(frame) for frame in frames_payload]
tracks_payload = payload.get(TRACKS_KEY, [])
tracks = (
[Track.from_json(track, client) for track in tracks_payload]
[
Track.from_json(track, connection=client.connection)
for track in tracks_payload
]
if client
else []
)
Expand Down Expand Up @@ -680,7 +683,10 @@ def from_json(
items = [DatasetItem.from_json(item) for item in items_payload]
tracks_payload = payload.get(TRACKS_KEY, [])
tracks = (
[Track.from_json(track, client) for track in tracks_payload]
[
Track.from_json(track, connection=client.connection)
for track in tracks_payload
]
if client
else []
)
Expand Down
10 changes: 5 additions & 5 deletions nucleus/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

if TYPE_CHECKING:
from . import NucleusClient
from . import Connection


@dataclass # pylint: disable=R0902
Expand All @@ -25,7 +25,7 @@ class Track: # pylint: disable=R0902
metadata: Arbitrary key/value dictionary of info to attach to this track.
"""

_client: "NucleusClient"
_connection: "Connection"
dataset_id: str
reference_id: str
metadata: Optional[dict] = None
Expand All @@ -41,10 +41,10 @@ def __eq__(self, other):
)

@classmethod
def from_json(cls, payload: dict, client: "NucleusClient"):
def from_json(cls, payload: dict, connection: "Connection"):
"""Instantiates track object from schematized JSON dict payload."""
return cls(
_client=client,
_connection=connection,
reference_id=str(payload[REFERENCE_ID_KEY]),
dataset_id=str(payload[DATASET_ID_KEY]),
metadata=payload.get(METADATA_KEY, None),
Expand Down Expand Up @@ -79,7 +79,7 @@ def update(
entire metadata object will be overwritten. Otherwise, only the keys in metadata will be overwritten.
"""

self._client.make_request(
self._connection.make_request(
payload={
REFERENCE_ID_KEY: self.reference_id,
METADATA_KEY: metadata,
Expand Down
2 changes: 1 addition & 1 deletion nucleus/validate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
]

from .client import Validate
from .constants import ThresholdComparison
from .constants import EntityLevel, ThresholdComparison
from .data_transfer_objects.eval_function import (
EvalFunctionEntry,
EvaluationCriterion,
Expand Down
2 changes: 1 addition & 1 deletion nucleus/validate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def create_external_eval_function(

Args:
name: unique name of evaluation function
level: level at which the eval function is run, defaults to "item"
level: level at which the eval function is run, defaults to EntityLevel.ITEM.

Raises:
- NucleusAPIError if the creation of the function fails on the server side
Expand Down
9 changes: 8 additions & 1 deletion nucleus/validate/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ class ThresholdComparison(str, Enum):


class EntityLevel(str, Enum):
    """Data level at which evaluation functions produce outputs.

    For instance, when comparing results across dataset items, use
    ``EntityLevel.ITEM``. For scenes, use ``EntityLevel.SCENE``. Finally,
    when comparing results between tracks within a single scene or a
    holistic item dataset, use ``EntityLevel.TRACK``.

    Inherits from ``str`` so members compare equal to, and serialize as,
    their plain string values (e.g. in JSON payloads).
    """

    # Track-level results (per track_ref_id)
    TRACK = "track"
    # Dataset-item-level results (per item_ref_id)
    ITEM = "item"
    # Scene-level results (per scene_ref_id)
    SCENE = "scene"
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


class EvaluationResult(ImmutableModel):
track_ref_id: Optional[str] = None
item_ref_id: Optional[str] = None
scene_ref_id: Optional[str] = None
score: float = 0
Expand All @@ -15,16 +16,15 @@ class EvaluationResult(ImmutableModel):
def is_item_or_scene_provided(
cls, values
): # pylint: disable=no-self-argument
if (
values.get("item_ref_id") is None
and values.get("scene_ref_id") is None
) or (
(
values.get("item_ref_id") is not None
and values.get("scene_ref_id") is not None
ref_ids = [
values.get("track_ref_id", None),
values.get("item_ref_id", None),
values.get("scene_ref_id", None),
]
if len([ref_id for ref_id in ref_ids if ref_id is not None]) != 1:
raise ValueError(
"Must provide exactly one of track_ref_id, item_ref_id, or scene_ref_id"
)
):
raise ValueError("Must provide either item_ref_id or scene_ref_id")
return values

@validator("score", "weight")
Expand Down
63 changes: 48 additions & 15 deletions nucleus/validate/scenario_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@
from typing import List, Optional, Union

from ..connection import Connection
from ..constants import DATASET_ITEMS_KEY, NAME_KEY, SCENES_KEY, SLICE_ID_KEY
from ..constants import (
DATASET_ITEMS_KEY,
NAME_KEY,
SCENES_KEY,
SLICE_ID_KEY,
TRACKS_KEY,
)
from ..dataset_item import DatasetItem
from ..scene import Scene
from ..track import Track
from .constants import (
EVAL_FUNCTION_ID_KEY,
SCENARIO_TEST_ID_KEY,
Expand Down Expand Up @@ -166,8 +173,8 @@ def get_eval_history(self) -> List[ScenarioTestEvaluation]:

def get_items(
self, level: EntityLevel = EntityLevel.ITEM
) -> Union[List[DatasetItem], List[Scene]]:
"""Gets items within a scenario test at a given level, returning a list of DatasetItem or Scene objects.
) -> Union[List[Track], List[DatasetItem], List[Scene]]:
"""Gets items within a scenario test at a given level, returning a list of Track, DatasetItem, or Scene objects.

Args:
level: :class:`EntityLevel`
Expand All @@ -178,14 +185,22 @@ def get_items(
response = self.connection.get(
f"validate/scenario_test/{self.id}/items",
)
if level == EntityLevel.TRACK:
return [
Track.from_json(track, connection=self.connection)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is the reason for the long justification above 😤

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, the struggle is real 🥇

for track in response.get(TRACKS_KEY, [])
]
if level == EntityLevel.SCENE:
return [
Scene.from_json(scene, skip_validate=True)
for scene in response[SCENES_KEY]
for scene in response.get(SCENES_KEY, [])
]
return [
DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
]
if level == EntityLevel.ITEM:
return [
DatasetItem.from_json(item)
for item in response.get(DATASET_ITEMS_KEY, [])
]
raise ValueError(f"Invalid entity level: {level}")

def set_baseline_model(self, model_id: str):
"""Sets a new baseline model for the ScenarioTest. In order to be eligible to be a baseline,
Expand Down Expand Up @@ -222,23 +237,41 @@ def upload_external_evaluation_results(
len(results) > 0
), "Submitting evaluation requires at least one result."

level = EntityLevel.ITEM
level: Optional[EntityLevel] = None
metric_per_ref_id = {}
weight_per_ref_id = {}
aggregate_weighted_sum = 0.0
aggregate_weight = 0.0

# Ensures results at only one EntityLevel are provided; otherwise a ValueError is raised.
def ensure_level_consistency_or_raise(
    cur_level: "Optional[EntityLevel]", new_level: "EntityLevel"
):
    """Raise ``ValueError`` if *new_level* conflicts with the already-seen *cur_level*.

    A mixed batch of results (e.g. some with ``item_ref_id`` and some with
    ``scene_ref_id``) cannot be aggregated at a single level, so the first
    observed level must match every subsequent one. ``cur_level is None``
    means no level has been seen yet and any *new_level* is acceptable.
    """
    # Compare the parameter, not the enclosing scope's `level` variable:
    # the original body read the closure variable, silently ignoring
    # `cur_level` (callers happened to pass `level`, masking the bug).
    if cur_level is not None and cur_level != new_level:
        raise ValueError(
            f"All evaluation results must only pertain to one level. Received {cur_level} then {new_level}"
        )

# aggregation based on https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
for r in results:
# Ensure results are uploaded ONLY for items or ONLY for scenes
# Ensure results are uploaded ONLY for ONE OF tracks, items, and scenes
if r.track_ref_id is not None:
ensure_level_consistency_or_raise(level, EntityLevel.TRACK)
level = EntityLevel.TRACK
if r.item_ref_id is not None:
ensure_level_consistency_or_raise(level, EntityLevel.ITEM)
level = EntityLevel.ITEM
if r.scene_ref_id is not None:
ensure_level_consistency_or_raise(level, EntityLevel.SCENE)
level = EntityLevel.SCENE
if r.item_ref_id is not None and level == EntityLevel.SCENE:
raise ValueError(
"All evaluation results must either pertain to a scene_ref_id or an item_ref_id, not both."
)
ref_id = (
r.item_ref_id if level == EntityLevel.ITEM else r.scene_ref_id
r.track_ref_id
if level == EntityLevel.TRACK
else (
r.item_ref_id
if level == EntityLevel.ITEM
else r.scene_ref_id
)
)

# Aggregate scores and weights
Expand All @@ -255,7 +288,7 @@ def upload_external_evaluation_results(
"overall_metric": aggregate_weighted_sum / aggregate_weight,
"model_id": model_id,
"slice_id": self.slice_id,
"level": level.value,
"level": level.value if level else None,
}
response = self.connection.post(
payload,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ exclude = '''

[tool.poetry]
name = "scale-nucleus"
version = "0.14.29"
version = "0.14.30"
description = "The official Python client library for Nucleus, the Data Platform for AI"
license = "MIT"
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]
Expand Down
1 change: 1 addition & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DATASET_WITH_EMBEDDINGS = "ds_c8jwdhy4y4f0078hzceg"
NUCLEUS_PYTEST_USER_ID = "60ad648c85db770026e9bf77"

EVAL_FUNCTION_NAME = "eval_fn"
EVAL_FUNCTION_THRESHOLD = 0.5
EVAL_FUNCTION_COMPARISON = ThresholdComparison.GREATER_THAN_EQUAL_TO

Expand Down
3 changes: 2 additions & 1 deletion tests/test_track.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from copy import deepcopy

import pytest
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_create_mp_with_tracks(CLIENT, dataset_scene):
expected_track_reference_ids = [
ann["track_reference_id"] for ann in TEST_SCENE_BOX_PREDS_WITH_TRACK
]
model_reference = "model_test_create_mp_with_tracks"
model_reference = "model_" + str(time.time())
model = CLIENT.create_model(TEST_MODEL_NAME, model_reference)

# Act
Expand Down
Loading