diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 0902b667..89eb3f91 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 - uses: actions/cache@v4 with: key: ${{ github.ref }} diff --git a/README.md b/README.md index 3d9e39cc..bda81dab 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ See the [docs](https://cesnet.github.io/dp3/howto/get-started/) for more details ### Installing for application development -Pre-requisites: Python 3.9 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. +Pre-requisites: Python 3.11 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. Create a virtualenv and install the DP³ platform using: @@ -117,7 +117,7 @@ You are now ready to start developing your application! ## Installing for platform development -Pre-requisites: Python 3.9 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. +Pre-requisites: Python 3.11 or higher, `pip` (with `virtualenv` installed), `git`, `Docker` and `Docker Compose`. Pull the repository and install using: diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index aea62099..426ae6d6 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 # Base interpreter with installed requirements -FROM python:3.9-slim AS base +FROM python:3.11-slim AS base RUN apt-get update; apt-get install -y \ gcc \ git diff --git a/docs/hooks.md b/docs/hooks.md index 530d6121..02d09d16 100644 --- a/docs/hooks.md +++ b/docs/hooks.md @@ -600,6 +600,8 @@ Code reference: [`scheduler_register`][dp3.common.callback_registrar.CallbackReg Most user-facing hooks return `list[DataPointTask]`. Whenever that happens, the returned tasks are fed back into the main ingestion system. +Each returned task must do useful work: it must contain at least one datapoint, carry non-empty TTL tokens, or be a delete task. +Empty `DataPointTask` objects are rejected during validation because they would be queued and processed without changing DP3 state. This creates a feedback loop: diff --git a/docs/howto/develop-dp3.md b/docs/howto/develop-dp3.md index 070cbffd..a29b5f31 100644 --- a/docs/howto/develop-dp3.md +++ b/docs/howto/develop-dp3.md @@ -16,7 +16,7 @@ You will end up with: For platform development, you need: -- Python 3.9 or higher +- Python 3.11 or higher - `pip` - `git` - Docker diff --git a/docs/howto/get-started.md b/docs/howto/get-started.md index 53c89818..ce6e6e19 100644 --- a/docs/howto/get-started.md +++ b/docs/howto/get-started.md @@ -15,7 +15,7 @@ You will end up with: For local application development, you need: -- Python 3.9 or higher +- Python 3.11 or higher - `pip` - `git` - Docker diff --git a/docs/howto/test-module.md b/docs/howto/test-module.md new file mode 100644 index 00000000..3f3a687a --- /dev/null +++ b/docs/howto/test-module.md @@ -0,0 +1,153 @@ +# Test a secondary module + +DP3 includes helpers for writing focused unit tests for secondary modules without running a full +worker, database, message broker, or snapshot scheduler. + +Use [`DP3ModuleTestCase`][dp3.testing.DP3ModuleTestCase] when you want to instantiate a +module with the application's real `db_entities` model and then call registered hooks directly. 
+The test registrar captures callbacks during module initialization and exposes runners for the
+common hook families.
+
+The config directory is read from the `DP3_CONFIG_DIR` environment variable unless a test class
+sets `config_dir` explicitly. Module configuration is read from `modules.<module_name>` in that
+config by default, where `<module_name>` is inferred from the module class's Python module name.
+
+```bash
+DP3_CONFIG_DIR=config python -m unittest discover -s tests -v
+```
+
+## Basic pattern
+
+```python
+from unittest.mock import patch
+
+from dp3.testing import DP3ModuleTestCase
+from modules.ip_exposure_profile import IPExposureProfile
+
+
+class TestIPExposureProfile(DP3ModuleTestCase):
+    module_class = IPExposureProfile
+
+    def test_open_port_creates_service_and_link(self):
+        dp = self.make_observation_datapoint("ip", "192.0.2.1", "open_ports", 443)
+
+        tasks = self.run_on_new_attr("ip", "open_ports", "192.0.2.1", dp)
+
+        self.assertDatapoint(tasks, etype="service", eid="192.0.2.1:443", attr="guessed_type")
+        self.assertDatapoint(tasks, etype="ip", eid="192.0.2.1", attr="services")
+
+    def test_updater_uses_mocked_external_lookup(self):
+        with patch.object(self.module, "_fetch_service_intel", return_value={"risk": "high"}):
+            tasks = self.run_periodic_update(
+                "service",
+                "192.0.2.1:443",
+                {"guessed_type": "https"},
+                hook_id="service_intel",
+            )
+
+        self.assertDatapoint(tasks, attr="external_risk", v="high")
+```
+
+## What the helper provides
+
+`DP3ModuleTestCase`:
+
+- loads `db_entities` from `DP3_CONFIG_DIR` or `config_dir` and builds a real `ModelSpec`,
+- creates a minimal `PlatformConfig`,
+- instantiates `module_class` with a test registrar,
+- creates validated `DataPointTask` and plain, observation, or timeseries datapoint objects using
+  the loaded model,
+- calls registered hooks directly,
+- provides partial-match assertions for emitted tasks, datapoints, and mutated records.
+
+The helper is intended for module-level unit tests. It does not run a database, task queues,
+worker processes, recursive task ingestion, or full linked snapshot loading.
+
+## Datapoint helpers
+
+Use the datapoint helpers to build values accepted by the loaded model specification:
+
+```python
+plain = self.make_plain_datapoint("ip", "192.0.2.1", "hostname", "host.example")
+observation = self.make_observation_datapoint("ip", "192.0.2.1", "open_ports", 443)
+timeseries = self.make_timeseries_datapoint(
+    "ip",
+    "192.0.2.1",
+    "traffic",
+    {"packets": [1, 2, 3], "bytes": [100, 200, 300]},
+)
+```
+
+For regular timeseries attributes, `make_timeseries_datapoint()` infers `t2` from `t1`, the
+configured `time_step`, and the number of samples when `t2` is not supplied.
+
+## Hook runners
+
+Common runners are available on the test case:
+
+- `run_allow_entity_creation(entity, eid, task=None)`
+- `run_on_entity_creation(entity, eid, task=None)`
+- `run_on_new_attr(entity, attr, eid, dp)`
+- `run_correlation_hooks(entity_type, record, master_record=None)`
+- `run_periodic_update(entity_type, eid, master_record, hook_id=None)`
+- `run_periodic_eid_update(entity_type, eid, hook_id=None)`
+- `run_scheduler_job(index_or_func)`
+
+Correlation tests pass the snapshot `record` explicitly. The record must contain `eid`.
+Scheduler jobs can be selected by registration index, callable, or callable name.
+
+## Assertions
+
+Assertions use partial matching: only fields supplied in the expected values are checked.
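+
+For example, a test that only cares about which attribute was emitted can match on `attr` alone
+(reusing the `run_on_new_attr` call from the basic pattern above):
+
+```python
+tasks = self.run_on_new_attr("ip", "open_ports", "192.0.2.1", dp)
+
+# Passes if any emitted datapoint has attr="services", regardless of etype, eid, or value.
+self.assertDatapoint(tasks, attr="services")
+```
+
+The full set of assertion helpers: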
+ +```python +self.assertDatapoint(tasks, etype="ip", attr="hostname", v="example.test") +self.assertTaskEmitted(tasks, etype="ip", eid="192.0.2.1") +self.assertNoTasks(tasks) +self.assertNoDatapoints(tasks) +self.assertRecordContains(record, exposure_score=10) +self.assertRecordAttr(record, "exposure_score", 10) +self.assertRecordUnchanged(before, after) +``` + +Snake-case aliases are also available: `assert_datapoint`, `assert_task_emitted`, +`assert_no_tasks`, `assert_no_datapoints`, `assert_record_contains`, `assert_record_attr`, and +`assert_record_unchanged`. + +## Registration assertions + +Use registration assertions when a test needs to verify callback coverage or dynamic hook +registration. + +```python +self.assert_registered("on_new_attr", entity="ip", attr="hostname") +self.assert_registered_once("correlation", entity_type="service") +self.assert_registered_attrs("service", expected_service_attrs) +self.assert_scheduler_registered(func="reload_ip_groups", minute="*/10") +``` + +`assert_scheduler_registered()` accepts scheduler fields such as `minute`, `hour`, and `second`, +along with `func` for matching the registered callable by object or function name. + +## Mocking external dependencies + +Patch external constructors or functions before module instantiation when the dependency is created +in `__init__` or `load_config`: + +```python +class TestDNSModule(DP3ModuleTestCase): + module_class = DNSModule + + def setUp(self): + self.resolver_patcher = patch("modules.dns_module.Resolver", FakeResolver) + self.resolver_patcher.start() + self.addCleanup(self.resolver_patcher.stop) + super().setUp() +``` + +If patching is not convenient, use a test subclass as `module_class` and override the module's +initialization or dependency construction while keeping the hook methods under test unchanged. + +Deprecated registrar methods (`register_entity_hook` and `register_attr_hook`) are supported by the +test registrar and emit `DeprecationWarning`. Prefer the modern registration methods in new module +code and tests. diff --git a/docs/modules.md b/docs/modules.md index e2882436..822322d4 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -172,6 +172,16 @@ and configuration, see the [updater configuration](configuration/updater.md) pag - [`scheduler_register(...)`](hooks.md#scheduler_register) — CRON-style module-level scheduled callback for maintenance, polling, housekeeping, or shared-state reloads. +## Testing modules + +Secondary modules can be unit-tested without running a full DP3 worker by using +[`DP3ModuleTestCase`][dp3.testing.DP3ModuleTestCase]. The helper loads an application's +real `db_entities` model from `DP3_CONFIG_DIR` or an explicit test fixture path, instantiates a +module with a test callback registrar, and lets tests call registered hooks directly with validated +`DataPointTask` and datapoint objects. + +See [Test a secondary module](howto/test-module.md) for examples and supported hook runners. + ## Running module code in a separate thread The module is free to run its own code in separate threads or processes. 
diff --git a/dp3/api/internal/config.py b/dp3/api/internal/config.py index d7ef4626..87ee3231 100644 --- a/dp3/api/internal/config.py +++ b/dp3/api/internal/config.py @@ -44,7 +44,7 @@ def validate(cls, v): try: # Validate and parse environmental variables - conf_env = ConfigEnv.parse_obj(os.environ) + conf_env = ConfigEnv.model_validate(os.environ) except ValidationError as e: config_error = any("CONF_DIR" in x["loc"] and len(x["loc"]) > 1 for x in e.errors()) env_error = any(len(x["loc"]) == 1 for x in e.errors()) diff --git a/dp3/api/internal/entity_response_models.py b/dp3/api/internal/entity_response_models.py index 89adba14..c483ae87 100644 --- a/dp3/api/internal/entity_response_models.py +++ b/dp3/api/internal/entity_response_models.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, Field, NonNegativeInt, PlainSerializer @@ -25,11 +25,11 @@ class EntityState(BaseModel): JsonVal = Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] LinkVal = dict[str, JsonVal] -PlainVal = Union[LinkVal, JsonVal] +PlainVal = LinkVal | JsonVal MultiVal = list[PlainVal] HistoryVal = list[dict[str, PlainVal]] -Dp3Val = Union[HistoryVal, MultiVal, PlainVal] +Dp3Val = HistoryVal | MultiVal | PlainVal EntityEidMasterRecord = dict[str, Dp3Val] @@ -45,7 +45,7 @@ class EntityEidList(BaseModel): Data does not include history of observations attributes and timeseries. """ - time_created: Optional[datetime] = None + time_created: datetime | None = None count: int data: EntityEidSnapshots diff --git a/dp3/api/internal/models.py b/dp3/api/internal/models.py index 65918730..83d660fa 100644 --- a/dp3/api/internal/models.py +++ b/dp3/api/internal/models.py @@ -1,4 +1,6 @@ -from typing import Annotated, Any, Literal, Optional, Union +from functools import reduce +from operator import or_ +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, TypeAdapter, create_model, model_validator @@ -26,10 +28,10 @@ class DataPoint(BaseModel): id: Any attr: str v: Any - t1: Optional[AwareDatetime] = None - t2: Optional[T2Datetime] = Field(None, validate_default=True) + t1: AwareDatetime | None = None + t2: T2Datetime | None = Field(None, validate_default=True) c: Annotated[float, Field(ge=0.0, le=1.0)] = 1.0 - src: Optional[str] = None + src: str | None = None @model_validator(mode="after") def validate_against_attribute(self): @@ -43,14 +45,14 @@ def validate_against_attribute(self): class EntityId(BaseModel): - """Dummy model for entity id + """Common interface for validated entity identifiers. 
Attributes: type: Entity type id: Entity ID """ - type: Literal["entity_type"] + type: str id: Any @@ -60,11 +62,11 @@ class EntityId(BaseModel): entity_id_models.append( create_model( f"EntityId{{{entity_type}}}", - __base__=BaseModel, + __base__=EntityId, type=(Literal[entity_type], Field(..., alias="etype")), id=(dtype, Field(..., alias="eid")), ) ) -EntityId = Annotated[Union[tuple(entity_id_models)], Field(discriminator="type")] # noqa: F811 -EntityIdAdapter = TypeAdapter(EntityId) +EntityIdType = Annotated[reduce(or_, entity_id_models), Field(discriminator="type")] +EntityIdAdapter = TypeAdapter(EntityIdType) diff --git a/dp3/api/routers/entity.py b/dp3/api/routers/entity.py index b952a4bc..209072c6 100644 --- a/dp3/api/routers/entity.py +++ b/dp3/api/routers/entity.py @@ -1,5 +1,5 @@ -from datetime import datetime -from typing import Annotated, Any, Optional +from datetime import UTC, datetime +from typing import Annotated, Any, cast from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import Json, NonNegativeInt, ValidationError @@ -22,7 +22,7 @@ from dp3.common.attrspec import AttrType from dp3.common.datapoint import to_json_friendly from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC, AwareDatetime +from dp3.common.types import AwareDatetime from dp3.database.database import DatabaseError @@ -33,10 +33,10 @@ async def check_etype(etype: str): return etype -async def parse_eid(etype: str, eid: str): +async def parse_eid(etype: str, eid: str) -> EntityId: """Middleware to parse EID""" try: - return EntityIdAdapter.validate_python({"etype": etype, "eid": eid}) + return cast(EntityId, EntityIdAdapter.validate_python({"etype": etype, "eid": eid})) except ValidationError as e: raise RequestValidationError(["path", "eid"], e.errors()[0]["msg"]) from e @@ -44,12 +44,12 @@ async def parse_eid(etype: str, eid: str): ParsedEid = Annotated[EntityId, Depends(parse_eid)] -def _parse_optional_eid(etype: str, eid: Optional[str]) -> Any: +def _parse_optional_eid(etype: str, eid: str | None) -> Any: """Parse optional entity id query parameter for entity-scoped endpoints.""" if eid is None: return None try: - return EntityIdAdapter.validate_python({"etype": etype, "eid": eid}).id + return cast(EntityId, EntityIdAdapter.validate_python({"etype": etype, "eid": eid})).id except ValidationError as e: raise RequestValidationError(["query", "eid"], e.errors()[0]["msg"]) from e @@ -72,7 +72,7 @@ def _raw_datapoint_to_response(raw_datapoint: dict[str, Any]) -> dict[str, Any]: def get_eid_master_record_handler( - e: EntityId, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: EntityId, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ): """Handler for getting master record of EID""" # TODO: This is probably not the most efficient way. 
Maybe gather only @@ -97,8 +97,8 @@ def get_eid_master_record_handler( def get_eid_snapshots_handler( e: EntityId, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: int = 0, limit: int = 0, ) -> list[dict[str, Any]]: @@ -271,9 +271,9 @@ async def count_entity_type_eids( ) async def get_entity_type_raw_datapoints( etype: str, - eid: Optional[str] = None, - attr: Optional[str] = None, - src: Optional[str] = None, + eid: str | None = None, + attr: str | None = None, + src: str | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 20, ) -> EntityRawDataPage: @@ -305,7 +305,7 @@ async def get_entity_type_raw_datapoints( @router.get("/{etype}/{eid}") async def get_eid_data( - e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: ParsedEid, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ) -> EntityEidData: """Get data of the entity identified by `etype` and `eid`. @@ -325,7 +325,7 @@ async def get_eid_data( @router.get("/{etype}/{eid}/master") async def get_eid_master_record( - e: ParsedEid, date_from: Optional[AwareDatetime] = None, date_to: Optional[AwareDatetime] = None + e: ParsedEid, date_from: AwareDatetime | None = None, date_to: AwareDatetime | None = None ) -> EntityEidMasterRecord: """Get the master record of the entity identified by `etype` and `eid`.""" return get_eid_master_record_handler(e, date_from, date_to) @@ -334,8 +334,8 @@ async def get_eid_master_record( @router.get("/{etype}/{eid}/snapshots") async def get_eid_snapshots( e: ParsedEid, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 0, ) -> EntityEidSnapshots: @@ -351,8 +351,8 @@ async def get_eid_snapshots( async def get_eid_attr_value( e: ParsedEid, attr: str, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, ) -> EntityEidAttrValueOrHistory: """Get attribute value @@ -373,9 +373,7 @@ async def get_eid_attr_value( @router.post("/{etype}/{eid}/set/{attr}") -async def set_eid_attr_value( - etype: str, eid: str, attr: str, body: EntityEidAttrValue, request: Request -) -> SuccessResponse: +async def set_eid_attr_value(etype: str, eid: str, attr: str, request: Request) -> SuccessResponse: """Set current value of attribute Internally just creates datapoint for specified attribute and value. 
@@ -386,6 +384,11 @@ async def set_eid_attr_value( if attr not in MODEL_SPEC.attribs(etype): raise RequestValidationError(["path", "attr"], f"Attribute '{attr}' doesn't exist") + try: + body = EntityEidAttrValue.model_validate(await request.json()) + except ValueError as e: + raise RequestValidationError(["body"], str(e)) from e + # Construct datapoint try: dp = DataPoint( @@ -396,7 +399,7 @@ async def set_eid_attr_value( t1=datetime.now(UTC), src=f"{request.client.host} via API", ) - dp3_dp = api_to_dp3_datapoint(dp.dict()) + dp3_dp = api_to_dp3_datapoint(dp.model_dump()) except ValidationError as e: raise RequestValidationError(["body", "value"], e.errors()[0]["msg"]) from e diff --git a/dp3/api/routers/telemetry.py b/dp3/api/routers/telemetry.py index 60acd77f..0c2fe9d7 100644 --- a/dp3/api/routers/telemetry.py +++ b/dp3/api/routers/telemetry.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Literal, Optional +from typing import Literal import requests from fastapi import APIRouter, HTTPException @@ -41,9 +41,9 @@ async def get_snapshot_summary() -> dict: @router.get("/metadata") async def get_metadata( - module: Optional[str] = None, - date_from: Optional[AwareDatetime] = None, - date_to: Optional[AwareDatetime] = None, + module: str | None = None, + date_from: AwareDatetime | None = None, + date_to: AwareDatetime | None = None, skip: NonNegativeInt = 0, limit: NonNegativeInt = 0, sort: Literal["newest", "oldest"] = "newest", diff --git a/dp3/bin/check.py b/dp3/bin/check.py index 8836de6b..b852db1a 100755 --- a/dp3/bin/check.py +++ b/dp3/bin/check.py @@ -218,7 +218,7 @@ def main(args): unique_sources = [] source_paths_and_errors = [] - for path, source, err in zip(paths, sources, errors): + for path, source, err in zip(paths, sources, errors, strict=False): if source in unique_sources: i = unique_sources.index(source) source_paths_and_errors[i].add((path, err)) @@ -226,7 +226,7 @@ def main(args): unique_sources.append(source) source_paths_and_errors.append({(path, err)}) - for source, paths_and_errors in zip(unique_sources, source_paths_and_errors): + for source, paths_and_errors in zip(unique_sources, source_paths_and_errors, strict=False): for path, err in paths_and_errors: print(" -> ".join(path)) print(" ", err) diff --git a/dp3/bin/shcmd/common.py b/dp3/bin/shcmd/common.py index 860951fd..266bde98 100644 --- a/dp3/bin/shcmd/common.py +++ b/dp3/bin/shcmd/common.py @@ -5,7 +5,7 @@ import os import sys from functools import lru_cache -from typing import Any, Optional +from typing import Any from urllib.parse import urljoin import requests @@ -28,9 +28,9 @@ class DP3APIClient: def __init__( self, config_dir: str, - base_url: Optional[str] = None, + base_url: str | None = None, timeout: float = 5.0, - model_spec: Optional[ModelSpec] = None, + model_spec: ModelSpec | None = None, ): self.config_dir = os.path.abspath(config_dir) self.model_spec = model_spec @@ -42,7 +42,7 @@ def __init__( def _normalize_base_url(base_url: str) -> str: return base_url.rstrip("/") + "/" - def _resolve_base_url(self, base_url: Optional[str]) -> str: + def _resolve_base_url(self, base_url: str | None) -> str: if base_url is not None: normalized = self._normalize_base_url(base_url) self._check_health(normalized) @@ -84,7 +84,7 @@ def request( method: str, path: str, *, - params: Optional[dict[str, Any]] = None, + params: dict[str, Any] | None = None, json_body: Any = None, stream: bool = False, ) -> requests.Response: @@ -115,7 +115,7 @@ def read_json_value(raw_value: str) -> Any: raise 
APIError(f"Invalid JSON value: {e}") from e -def read_json_input(path: Optional[str]) -> Any: +def read_json_input(path: str | None) -> Any: """Decode JSON from a file path or standard input.""" if path in (None, "-"): content = sys.stdin.read() @@ -207,7 +207,7 @@ def stream_json_pages( return 0 -def resolve_config_dir(config_dir: Optional[str]) -> str: +def resolve_config_dir(config_dir: str | None) -> str: """Resolve the configuration directory for the shell-oriented CLI.""" if config_dir is not None: return os.path.abspath(config_dir) @@ -217,7 +217,7 @@ def resolve_config_dir(config_dir: Optional[str]) -> str: @lru_cache(maxsize=32) -def load_completion_model_spec(config_dir: str) -> Optional[ModelSpec]: +def load_completion_model_spec(config_dir: str) -> ModelSpec | None: """Load the model specification used by shell completion.""" try: config = read_config_dir(config_dir, recursive=True) @@ -228,8 +228,8 @@ def load_completion_model_spec(config_dir: str) -> Optional[ModelSpec]: @lru_cache(maxsize=32) def load_completion_entity_catalog( - config_dir: str, base_url: Optional[str], timeout: float -) -> Optional[dict[str, Any]]: + config_dir: str, base_url: str | None, timeout: float +) -> dict[str, Any] | None: """Load entity metadata from the API when config-based completion is unavailable.""" try: client = DP3APIClient(config_dir, base_url, timeout) @@ -241,7 +241,7 @@ def load_completion_entity_catalog( def get_completion_context( parsed_args, -) -> tuple[Optional[ModelSpec], Optional[dict[str, Any]]]: +) -> tuple[ModelSpec | None, dict[str, Any] | None]: """Return completion metadata derived from config and API sources.""" config_dir = resolve_config_dir(getattr(parsed_args, "config", None)) model_spec = load_completion_model_spec(config_dir) @@ -257,8 +257,8 @@ def get_completion_context( def _entity_type_description( etype: str, - model_spec: Optional[ModelSpec], - entity_catalog: Optional[dict[str, Any]], + model_spec: ModelSpec | None, + entity_catalog: dict[str, Any] | None, ) -> str: if model_spec is not None and etype in model_spec.entities: entity_spec = model_spec.entity(etype) diff --git a/dp3/bin/shcmd/entity/__init__.py b/dp3/bin/shcmd/entity/__init__.py index c8874972..46de9e9a 100644 --- a/dp3/bin/shcmd/entity/__init__.py +++ b/dp3/bin/shcmd/entity/__init__.py @@ -2,7 +2,6 @@ """Entity commands for the shell-oriented DP3 CLI.""" import argparse -from typing import Optional from . 
import etype from .common import complete_entity_rest, complete_entity_selector @@ -25,7 +24,7 @@ def _build_overview_parser() -> argparse.ArgumentParser: return parser -def parse_entity_command(args) -> tuple[Optional[argparse.Namespace], Optional[int]]: +def parse_entity_command(args) -> tuple[argparse.Namespace | None, int | None]: """Parse the path-like entity command grammar.""" overview_parser = _build_overview_parser() if args.selector is None: diff --git a/dp3/bin/shcmd/entity/common.py b/dp3/bin/shcmd/entity/common.py index acf620f2..0b3eb9a0 100644 --- a/dp3/bin/shcmd/entity/common.py +++ b/dp3/bin/shcmd/entity/common.py @@ -3,7 +3,7 @@ import argparse import json -from typing import Any, Optional +from typing import Any from argcomplete.finders import CompletionFinder @@ -154,7 +154,7 @@ def build_has_attr_filter(client, etype: str, attr: str) -> dict[str, Any]: return query -def build_generic_filter_param(client, args) -> Optional[str]: +def build_generic_filter_param(client, args) -> str | None: """Build the generic-filter query parameter for entity type queries.""" query = None if getattr(args, "filter_json", None) is not None: @@ -240,7 +240,7 @@ def _match_descriptions(values: dict[str, str], prefix: str) -> dict[str, str]: return {value: description for value, description in values.items() if value.startswith(prefix)} -def _entity_types(model_spec, entity_catalog: Optional[dict[str, Any]] = None) -> list[str]: +def _entity_types(model_spec, entity_catalog: dict[str, Any] | None = None) -> list[str]: if model_spec is not None: return sorted(model_spec.entities) if entity_catalog is not None: @@ -249,7 +249,7 @@ def _entity_types(model_spec, entity_catalog: Optional[dict[str, Any]] = None) - def _entity_attrs( - model_spec, etype: str, entity_catalog: Optional[dict[str, Any]] = None + model_spec, etype: str, entity_catalog: dict[str, Any] | None = None ) -> list[str]: if model_spec is not None and etype in model_spec.entities: return sorted(model_spec.attribs(etype)) @@ -259,7 +259,7 @@ def _entity_attrs( def _entity_attr_descriptions( - model_spec, etype: str, entity_catalog: Optional[dict[str, Any]] = None + model_spec, etype: str, entity_catalog: dict[str, Any] | None = None ) -> dict[str, str]: attrs = _entity_attrs(model_spec, etype, entity_catalog) descriptions = {attr: f"Attribute on entity type '{etype}'." 
for attr in attrs} diff --git a/dp3/common/attrspec.py b/dp3/common/attrspec.py index 79536837..86d26efb 100644 --- a/dp3/common/attrspec.py +++ b/dp3/common/attrspec.py @@ -1,6 +1,6 @@ from datetime import timedelta from enum import Flag -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -68,11 +68,11 @@ def from_str(cls, type_str: str): class ObservationsHistoryParams(BaseModel): """History parameters field of observations attribute""" - max_age: Optional[ParsedTimedelta] = None - max_items: Optional[PositiveInt] = None - expire_time: Optional[ParsedTimedelta] = None - pre_validity: Optional[ParsedTimedelta] = timedelta() - post_validity: Optional[ParsedTimedelta] = timedelta() + max_age: ParsedTimedelta | None = None + max_items: PositiveInt | None = None + expire_time: ParsedTimedelta | None = None + pre_validity: ParsedTimedelta | None = timedelta() + post_validity: ParsedTimedelta | None = timedelta() aggregate: bool = True @@ -85,8 +85,8 @@ def expire_time_inf_transform(cls, v): class TimeseriesTSParams(BaseModel): """Timeseries parameters field of timeseries attribute""" - max_age: Optional[ParsedTimedelta] = None - time_step: Optional[ParsedTimedelta] = None + max_age: ParsedTimedelta | None = None + time_step: ParsedTimedelta | None = None class TimeseriesSeries(BaseModel): @@ -119,7 +119,7 @@ class AttrSpecGeneric(SpecModel, use_enum_values=True): id: str = Field(pattern=ID_REGEX) name: str description: str = "" - ttl: Optional[ParsedTimedelta] = timedelta() + ttl: ParsedTimedelta | None = timedelta() _dp_model = PrivateAttr() @@ -298,7 +298,7 @@ def add_default_series(cls, v, info: FieldValidationInfo): - [AttrSpecObservations][dp3.common.attrspec.AttrSpecObservations] - [AttrSpecTimeseries][dp3.common.attrspec.AttrSpecTimeseries] """ -AttrSpecType = Union[AttrSpecTimeseries, AttrSpecObservations, AttrSpecPlain] +AttrSpecType = AttrSpecTimeseries | AttrSpecObservations | AttrSpecPlain def AttrSpec(id: str, spec: dict[str, Any]) -> AttrSpecType: diff --git a/dp3/common/callback_registrar.py b/dp3/common/callback_registrar.py index 675b4fcb..9e23414f 100644 --- a/dp3/common/callback_registrar.py +++ b/dp3/common/callback_registrar.py @@ -1,7 +1,8 @@ import logging +from collections.abc import Callable from functools import partial, wraps from logging import Logger -from typing import Any, Callable, Union +from typing import Any from pydantic import BaseModel @@ -11,7 +12,7 @@ from dp3.common.datatype import AnyEidT from dp3.common.scheduler import Scheduler from dp3.common.state import SharedFlag -from dp3.common.task import DataPointTask +from dp3.common.task import DataPointTask, task_context from dp3.common.types import ParsedTimedelta from dp3.core.updater import Updater from dp3.snapshots.snapshooter import SnapShooter @@ -57,7 +58,8 @@ def on_entity_creation_in_snapshots( if not run_flag.isset(): return [] eid = record["eid"] - mock_task = DataPointTask(etype=etype, eid=eid, data_points=[]) + with task_context(model_spec, allow_empty_data_point_task=True): + mock_task = DataPointTask(etype=etype, eid=eid, data_points=[]) tasks = original_hook(eid, mock_task) write_datapoints_into_record(model_spec, tasks, record) return tasks @@ -66,7 +68,7 @@ def on_entity_creation_in_snapshots( def on_attr_change_in_snapshots( model_spec: ModelSpec, run_flag: SharedFlag, - original_hook: Callable[[AnyEidT, DataPointTask], Union[list[DataPointTask], None]], + original_hook: Callable[[AnyEidT, 
DataPointTask], list[DataPointTask] | None], etype: str, record: dict, ) -> list[DataPointTask]: @@ -74,7 +76,8 @@ def on_attr_change_in_snapshots( if not run_flag.isset(): return [] eid = record["eid"] - mock_task = DataPointTask(etype=etype, eid=eid, data_points=[]) + with task_context(model_spec, allow_empty_data_point_task=True): + mock_task = DataPointTask(etype=etype, eid=eid, data_points=[]) tasks = original_hook(eid, mock_task) if isinstance(tasks, list): write_datapoints_into_record(model_spec, tasks, record) @@ -141,16 +144,16 @@ def scheduler_register( self, func: Callable, *, - func_args: Union[list, tuple] = None, + func_args: list | tuple = None, func_kwargs: dict = None, - year: Union[int, str] = None, - month: Union[int, str] = None, - day: Union[int, str] = None, - week: Union[int, str] = None, - day_of_week: Union[int, str] = None, - hour: Union[int, str] = None, - minute: Union[int, str] = None, - second: Union[int, str] = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, timezone: str = "UTC", misfire_grace_time: int = 1, ) -> int: @@ -273,7 +276,7 @@ def register_entity_hook(self, hook_type: str, hook: Callable, entity: str): def register_on_new_attr_hook( self, - hook: Callable[[AnyEidT, DataPointType], Union[None, list[DataPointTask]]], + hook: Callable[[AnyEidT, DataPointType], None | list[DataPointTask]], entity: str, attr: str, refresh: SharedFlag = None, @@ -355,7 +358,7 @@ def register_timeseries_hook( def register_correlation_hook( self, - hook: Callable[[str, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -387,7 +390,7 @@ def register_correlation_hook( def register_correlation_hook_with_master_record( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], diff --git a/dp3/common/config.py b/dp3/common/config.py index 2b538740..1e2c86f9 100644 --- a/dp3/common/config.py +++ b/dp3/common/config.py @@ -6,13 +6,12 @@ from collections.abc import Iterator from contextlib import contextmanager from contextvars import ContextVar -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, TypeVar import yaml from pydantic import ( BaseModel, ConfigDict, - Extra, Field, NonNegativeInt, PositiveInt, @@ -32,6 +31,8 @@ from dp3.common.datatype import AnyEidT from dp3.common.entityspec import EntitySpec +_T = TypeVar("_T") + class NoDefault: pass @@ -50,7 +51,7 @@ def __repr__(self): def copy(self): return HierarchicalDict(dict.copy(self)) - def get(self, key, default=NoDefault): + def get(self, key: str, default: type[NoDefault] | _T = NoDefault) -> Any | _T: """ Key may be a path (in dot notation) into a hierarchy of dicts. For example `dictionary.get('abc.x.y')` @@ -157,7 +158,7 @@ def read_config_dir(dir_path: str, recursive: bool = False) -> HierarchicalDict: TimeInt = Annotated[int, Field(ge=0, le=59)] -class CronExpression(BaseModel, extra=Extra.forbid): +class CronExpression(BaseModel): """ Cron expression used for scheduling. 
Also support standard cron expressions, such as @@ -177,16 +178,18 @@ class CronExpression(BaseModel, extra=Extra.forbid): timezone: Timezone for time specification (default is UTC). """ - second: Optional[Union[TimeInt, CronStr]] = None - minute: Optional[Union[TimeInt, CronStr]] = None - hour: Optional[Union[TimeInt, CronStr]] = None + model_config = ConfigDict(extra="forbid") + + second: TimeInt | CronStr | None = None + minute: TimeInt | CronStr | None = None + hour: TimeInt | CronStr | None = None - day: Optional[Union[Annotated[int, Field(ge=1, le=31)], CronStr]] = None - day_of_week: Optional[Union[Annotated[int, Field(ge=0, le=6)], CronStr]] = None + day: Annotated[int, Field(ge=1, le=31)] | CronStr | None = None + day_of_week: Annotated[int, Field(ge=0, le=6)] | CronStr | None = None - week: Optional[int] = Field(default=None, ge=1, le=53) - month: Optional[int] = Field(default=None, ge=1, le=12) - year: Optional[str] = Field(default=None, pattern=r"^\d{4}$") + week: int | None = Field(default=None, ge=1, le=53) + month: int | None = Field(default=None, ge=1, le=12) + year: str | None = Field(default=None, pattern=r"^\d{4}$") timezone: str = "UTC" diff --git a/dp3/common/control.py b/dp3/common/control.py index c3ccf3f7..83b9d9f8 100644 --- a/dp3/common/control.py +++ b/dp3/common/control.py @@ -3,8 +3,8 @@ """ import logging +from collections.abc import Callable from enum import Enum -from typing import Callable from pydantic import BaseModel diff --git a/dp3/common/datapoint.py b/dp3/common/datapoint.py index ee8850c0..d6c4a945 100644 --- a/dp3/common/datapoint.py +++ b/dp3/common/datapoint.py @@ -1,5 +1,5 @@ from ipaddress import IPv4Address, IPv6Address -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer @@ -8,7 +8,7 @@ def to_json_friendly(v): - if isinstance(v, (IPv4Address, IPv6Address, MACAddress)): + if isinstance(v, IPv4Address | IPv6Address | MACAddress): return str(v) return v @@ -33,7 +33,7 @@ class DataPointBase(BaseModel, use_enum_values=True): etype: str eid: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] = None attr: str - src: Optional[str] = None + src: str | None = None v: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] = None c: Any = None t1: Any = None @@ -147,10 +147,10 @@ def dp_ts_root_validator_irregular_intervals(self): # Check time_first[i] <= time_last[i] assert all( - t[0] <= t[1] for t in zip(self.v.time_first, self.v.time_last) + t[0] <= t[1] for t in zip(self.v.time_first, self.v.time_last, strict=False) ), "'time_first[i] <= time_last[i]' isn't true for all 'i'" return self -DataPointType = Union[DataPointPlainBase, DataPointObservationsBase, DataPointTimeseriesBase] +DataPointType = DataPointPlainBase | DataPointObservationsBase | DataPointTimeseriesBase diff --git a/dp3/common/datatype.py b/dp3/common/datatype.py index fe8b80ce..65930e34 100644 --- a/dp3/common/datatype.py +++ b/dp3/common/datatype.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv6Address -from typing import Any, Optional, Union +from typing import Any from pydantic import ( BaseModel, @@ -38,12 +38,12 @@ "mac": MACAddress, "time": datetime, "special": Any, - "json": Union[Json[Any], dict, list], + "json": Json[Any] | dict | list, } eid_data_types = ["string", "int", "ipv4", "ipv6", "mac"] -AnyEidT = Union[str, int, IPv4Address, IPv6Address, MACAddress] +AnyEidT = str | int 
| IPv4Address | IPv6Address | MACAddress """Type alias for any of possible entity ID data types. Note that the type is determined based on the loaded entity configuration @@ -190,7 +190,7 @@ def _determine_value_validator(self): # Set (type, default value) for the key if k_optional: k = k[:-1] # Remove question mark from key - dict_spec[k] = (Optional[primitive_data_types[v]], None) + dict_spec[k] = (primitive_data_types[v] | None, None) else: dict_spec[k] = (primitive_data_types[v], ...) @@ -228,7 +228,7 @@ def determine_value_validator(self): return self._determine_value_validator() @property - def data_type(self) -> Union[type, BaseModel]: + def data_type(self) -> type | BaseModel: """Type for incoming value validation""" return self._data_type @@ -263,7 +263,7 @@ def mirror_link(self) -> bool: return self._mirror_link @property - def mirror_as(self) -> Union[str, None]: + def mirror_as(self) -> str | None: """If `mirror_link`, what is the name of the mirrored attribute""" return self._mirror_as diff --git a/dp3/common/entityspec.py b/dp3/common/entityspec.py index f0e0fba7..5992b075 100644 --- a/dp3/common/entityspec.py +++ b/dp3/common/entityspec.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, Field, PrivateAttr, model_validator @@ -62,7 +62,7 @@ class EntitySpec(SpecModel): id_data_type: EidDataType = EidDataType("string") name: str snapshot: bool - lifetime: Union[ImmortalLifetime, TimeToLiveLifetime, WeakLifetime] = Field( + lifetime: ImmortalLifetime | TimeToLiveLifetime | WeakLifetime = Field( default_factory=lambda: ImmortalLifetime(type="immortal"), discriminator="type" ) diff --git a/dp3/common/mac_address.py b/dp3/common/mac_address.py index 28100d48..936cfad6 100644 --- a/dp3/common/mac_address.py +++ b/dp3/common/mac_address.py @@ -1,4 +1,6 @@ -from typing import Any, Union +from __future__ import annotations + +from typing import Any from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -10,7 +12,7 @@ class MACAddress: Can be initialized from colon or comma separated string, or from raw bytes. 
""" - def __init__(self, mac: Union[bytes, str, "MACAddress"]): + def __init__(self, mac: bytes | str | MACAddress): if isinstance(mac, self.__class__): mac = mac.mac # type: ignore if not isinstance(mac, bytes) or len(mac) != 6: @@ -19,11 +21,11 @@ def __init__(self, mac: Union[bytes, str, "MACAddress"]): self.mac: bytes = mac @classmethod - def _validate(cls, value: Any) -> "MACAddress": + def _validate(cls, value: Any) -> MACAddress: return cls(value) @classmethod - def _serialize(cls, value: "MACAddress", info: Any) -> Any: + def _serialize(cls, value: MACAddress, info: Any) -> Any: if info.mode == "json": return str(value) return value @@ -33,7 +35,7 @@ def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: base_schema = core_schema.no_info_after_validator_function( - cls._validate, handler(Union[str, bytes]) + cls._validate, handler(str | bytes) ) python_schema = core_schema.union_schema( @@ -52,7 +54,7 @@ def __get_pydantic_core_schema__( ) @staticmethod - def _parse_mac(mac: Union[bytes, str]) -> bytes: + def _parse_mac(mac: bytes | str) -> bytes: if isinstance(mac, str): mac = mac.encode() if not isinstance(mac, bytes): diff --git a/dp3/common/scheduler.py b/dp3/common/scheduler.py index 75b03ac8..031d2bfc 100644 --- a/dp3/common/scheduler.py +++ b/dp3/common/scheduler.py @@ -6,7 +6,7 @@ """ import logging -from typing import Callable, Union +from collections.abc import Callable from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger @@ -39,16 +39,16 @@ def stop(self) -> None: def register( self, func: Callable, - func_args: Union[list, tuple] = None, + func_args: list | tuple = None, func_kwargs: dict = None, - year: Union[int, str] = None, - month: Union[int, str] = None, - day: Union[int, str] = None, - week: Union[int, str] = None, - day_of_week: Union[int, str] = None, - hour: Union[int, str] = None, - minute: Union[int, str] = None, - second: Union[int, str] = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, timezone: str = "UTC", misfire_grace_time: int = 1, ) -> int: diff --git a/dp3/common/task.py b/dp3/common/task.py index e2ece755..3c319666 100644 --- a/dp3/common/task.py +++ b/dp3/common/task.py @@ -1,12 +1,12 @@ import hashlib from abc import ABC, abstractmethod -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from contextvars import ContextVar from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv6Address -from typing import Annotated, Any, Callable, Optional, Union +from typing import Annotated, Any from pydantic import ( AfterValidator, @@ -17,7 +17,9 @@ Tag, TypeAdapter, ValidationError, + ValidationInfo, field_validator, + model_validator, ) from pydantic_core.core_schema import FieldValidationInfo @@ -40,9 +42,16 @@ def HASH(key: str) -> int: @contextmanager -def task_context(model_spec: ModelSpec) -> Iterator[None]: +def task_context( + model_spec: ModelSpec, *, allow_empty_data_point_task: bool = False +) -> Iterator[None]: """Context manager for setting the `model_spec` context variable.""" - token = _init_context_var.set({"model_spec": model_spec}) + token = _init_context_var.set( + { + "model_spec": model_spec, + "allow_empty_data_point_task": 
allow_empty_data_point_task, + } + ) try: yield finally: @@ -141,7 +150,7 @@ class DataPointTask(Task): eid: Annotated[Any, PlainSerializer(to_json_friendly, when_used="json")] data_points: list[ValidatedDataPoint] = [] tags: list[Any] = [] - ttl_tokens: Optional[dict[str, datetime]] = None + ttl_tokens: dict[str, datetime] | None = None delete: bool = False def __init__(__pydantic_self__, **data: Any) -> None: @@ -182,6 +191,18 @@ def validate_eid(cls, v, info: FieldValidationInfo): else: raise AssertionError("Missing `model_spec` in context") + @model_validator(mode="after") + def validate_not_empty(self, info: ValidationInfo): + context = info.context + if context and context.get("allow_empty_data_point_task"): + return self + if not self.data_points and not self.ttl_tokens and not self.delete: + raise ValueError( + "DataPointTask must contain at least one datapoint, non-empty ttl_tokens, " + "or be a delete task." + ) + return self + def parse_data_point_task(task: str, model_spec: ModelSpec) -> DataPointTask: with task_context(model_spec): @@ -201,13 +222,11 @@ def get_discriminator_value(entity_tuple: tuple[str, Any]) -> str: EntityTuple = Annotated[ - Union[ - Annotated[tuple[str, str], Tag("string")], - Annotated[tuple[str, int], Tag("int")], - Annotated[tuple[str, IPv4Address], Tag("ipv4")], - Annotated[tuple[str, IPv6Address], Tag("ipv6")], - Annotated[tuple[str, MACAddress], Tag("mac")], - ], + Annotated[tuple[str, str], Tag("string")] + | Annotated[tuple[str, int], Tag("int")] + | Annotated[tuple[str, IPv4Address], Tag("ipv4")] + | Annotated[tuple[str, IPv6Address], Tag("ipv6")] + | Annotated[tuple[str, MACAddress], Tag("mac")], Discriminator(get_discriminator_value), ] @@ -238,8 +257,8 @@ def as_message(self) -> str: return self.model_dump_json() @staticmethod - def get_validator(model_spec: ModelSpec) -> Callable[[Union[str, bytes]], "Snapshot"]: - def json_validator(serialized: Union[str, bytes]) -> Snapshot: + def get_validator(model_spec: ModelSpec) -> Callable[[str | bytes], "Snapshot"]: + def json_validator(serialized: str | bytes) -> Snapshot: with entity_type_context(model_spec): return Snapshot.model_validate_json(serialized) diff --git a/dp3/common/types.py b/dp3/common/types.py index 922963da..ab5f81fe 100644 --- a/dp3/common/types.py +++ b/dp3/common/types.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from ipaddress import IPv4Address, IPv6Address from json import JSONEncoder -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any from event_count_logger import DummyEventGroup, EventGroup from pydantic import AfterValidator, BeforeValidator @@ -9,8 +9,6 @@ from dp3.common.utils import parse_time_duration, time_duration_pattern -UTC = timezone.utc - def parse_timedelta_or_passthrough(v): """ @@ -24,7 +22,7 @@ def parse_timedelta_or_passthrough(v): ParsedTimedelta = Annotated[timedelta, BeforeValidator(parse_timedelta_or_passthrough)] -def ensure_timezone_aware(v: Optional[datetime]): +def ensure_timezone_aware(v: datetime | None): """Ensure datetime is timezone-aware by defaulting to UTC.""" if v is None: return v @@ -55,7 +53,7 @@ def t2_after_t1(v, info: FieldValidationInfo): AfterValidator(t2_after_t1), ] -EventGroupType = Union[EventGroup, DummyEventGroup] +EventGroupType = EventGroup | DummyEventGroup class DP3Encoder(JSONEncoder): @@ -64,6 +62,6 @@ class DP3Encoder(JSONEncoder): def default(self, o: Any) -> Any: if isinstance(o, datetime): return 
o.strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-4] - if isinstance(o, (IPv4Address, IPv6Address)): + if isinstance(o, IPv4Address | IPv6Address): return str(o) return super().default(o) diff --git a/dp3/common/utils.py b/dp3/common/utils.py index 742bb258..52a2371e 100644 --- a/dp3/common/utils.py +++ b/dp3/common/utils.py @@ -7,7 +7,6 @@ from datetime import datetime, timedelta from functools import partial from itertools import islice -from typing import Union # *** IP conversion functions *** ipv4_re = re.compile(r"^([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})$") @@ -74,7 +73,7 @@ def parse_rfc_time(time_str): time_duration_pattern = re.compile(r"^\s*(\d+)([smhd])?$") -def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelta: +def parse_time_duration(duration_string: str | int | timedelta) -> timedelta: """ Parse duration in format (or just "0"). @@ -84,7 +83,7 @@ def parse_time_duration(duration_string: Union[str, int, timedelta]) -> timedelt if isinstance(duration_string, timedelta): return duration_string # if number is passed, consider it number of seconds - if isinstance(duration_string, (int, float)): + if isinstance(duration_string, int | float): return timedelta(seconds=duration_string) d = 0 diff --git a/dp3/core/collector.py b/dp3/core/collector.py index f6ea9785..31801f7f 100644 --- a/dp3/core/collector.py +++ b/dp3/core/collector.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from functools import partial from pydantic import BaseModel @@ -15,7 +15,6 @@ from dp3.common.datapoint import DataPointBase, DataPointObservationsBase, DataPointTimeseriesBase from dp3.common.datatype import AnyEidT from dp3.common.task import DataPointTask, parse_eids_from_cache -from dp3.common.types import UTC from dp3.database.database import EntityDatabase DB_SEND_CHUNK = 1000 diff --git a/dp3/core/link_manager.py b/dp3/core/link_manager.py index 94aedc09..a0f60393 100644 --- a/dp3/core/link_manager.py +++ b/dp3/core/link_manager.py @@ -3,7 +3,7 @@ """ import logging -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from functools import partial from pymongo import DeleteMany @@ -14,7 +14,6 @@ from dp3.common.datapoint import DataPointBase, DataPointObservationsBase from dp3.common.datatype import AnyEidT from dp3.common.task import parse_eids_from_cache -from dp3.common.types import UTC from dp3.database.database import EntityDatabase diff --git a/dp3/core/updater.py b/dp3/core/updater.py index 1bd507b8..bccb1fa6 100644 --- a/dp3/core/updater.py +++ b/dp3/core/updater.py @@ -2,10 +2,10 @@ import logging from collections import defaultdict -from collections.abc import Iterator -from datetime import datetime, timedelta +from collections.abc import Callable, Iterator +from datetime import UTC, datetime, timedelta from functools import partial -from typing import Callable, Literal +from typing import Literal from pydantic import BaseModel, validate_call from pymongo.cursor import Cursor @@ -14,7 +14,7 @@ from dp3.common.config import CronExpression, PlatformConfig from dp3.common.scheduler import Scheduler from dp3.common.task import DataPointTask, task_context -from dp3.common.types import UTC, EventGroupType, ParsedTimedelta +from dp3.common.types import EventGroupType, ParsedTimedelta from dp3.database.database import EntityDatabase from dp3.task_processing.task_queue import TaskQueueWriter diff --git 
a/dp3/database/config.py b/dp3/database/config.py index 3137736e..573cce01 100644 --- a/dp3/database/config.py +++ b/dp3/database/config.py @@ -1,5 +1,5 @@ import urllib -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, Field, field_validator @@ -38,7 +38,7 @@ class MongoConfig(BaseModel, extra="forbid"): db_name: str = "dp3" username: str = "dp3" password: str = "dp3" - connection: Union[MongoStandaloneConfig, MongoReplicaConfig] = Field(..., discriminator="mode") + connection: MongoStandaloneConfig | MongoReplicaConfig = Field(..., discriminator="mode") storage: StorageConfig = StorageConfig() @field_validator("username", "password") diff --git a/dp3/database/database.py b/dp3/database/database.py index e676bed6..4a2bfede 100644 --- a/dp3/database/database.py +++ b/dp3/database/database.py @@ -3,9 +3,8 @@ import threading import time from collections import defaultdict -from collections.abc import Generator, Iterator -from datetime import datetime -from typing import Callable, Optional +from collections.abc import Callable, Generator, Iterator +from datetime import UTC, datetime import pymongo from event_count_logger import DummyEventGroup @@ -24,7 +23,7 @@ from dp3.common.datatype import AnyEidT from dp3.common.scheduler import Scheduler from dp3.common.task import HASH -from dp3.common.types import UTC, EventGroupType +from dp3.common.types import EventGroupType from dp3.database.config import MongoConfig, MongoReplicaConfig, MongoStandaloneConfig from dp3.database.encodings import get_codec_options from dp3.database.exceptions import DatabaseError @@ -62,7 +61,7 @@ def __init__( model_spec: ModelSpec, num_processes: int, process_index: int = 0, - elog: Optional[EventGroupType] = None, + elog: EventGroupType | None = None, ) -> None: self.log = logging.getLogger("EntityDatabase") self.elog = elog or DummyEventGroup() @@ -559,7 +558,7 @@ def update_master_records(self, etype: str, eids: list[AnyEidT], records: list[d res = master_col.bulk_write( [ ReplaceOne({"_id": eid}, record, upsert=True) - for eid, record in zip(eids, records) + for eid, record in zip(eids, records, strict=False) ], ordered=False, ) @@ -776,7 +775,7 @@ def delete_many_link_dps( try: updates = [] for etype, affected_eid_list, attr_name, eid_to_list in zip( - etypes, affected_eids, attr_names, eids_to + etypes, affected_eids, attr_names, eids_to, strict=False ): master_col = self._master_col(etype) attr_type = self._db_schema_config.attr(etype, attr_name).t @@ -836,8 +835,8 @@ def get_value_or_history( etype: str, attr_name: str, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, ) -> dict: """Gets current value and/or history of attribute for given `eid`. 
@@ -878,12 +877,12 @@ def estimate_count_eids(self, etype: str) -> int: master_col = self._master_col(etype) return master_col.estimated_document_count({}) - def _get_metadata_id(self, module: str, time: datetime, worker_id: Optional[int] = None) -> str: + def _get_metadata_id(self, module: str, time: datetime, worker_id: int | None = None) -> str: """Generates unique metadata id based on `module`, `time` and the worker index.""" worker_id = self._process_index if worker_id is None else worker_id return f"{module}{time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}w{worker_id}" - def save_metadata(self, time: datetime, metadata: dict, worker_id: Optional[int] = None): + def save_metadata(self, time: datetime, metadata: dict, worker_id: int | None = None): """Saves metadata dict under the caller module and passed timestamp.""" module = get_caller_id() metadata["_id"] = self._get_metadata_id(module, time, worker_id) @@ -897,7 +896,7 @@ def save_metadata(self, time: datetime, metadata: dict, worker_id: Optional[int] raise DatabaseError(f"Insert of metadata failed: {e}\n{metadata}") from e def update_metadata( - self, time: datetime, metadata: dict, increase: dict = None, worker_id: Optional[int] = None + self, time: datetime, metadata: dict, increase: dict = None, worker_id: int | None = None ): """Updates existing metadata of caller module and passed timestamp.""" module = get_caller_id() @@ -1120,7 +1119,7 @@ def move_raw_to_archive(self, etype: str): except Exception as e: raise DatabaseError(f"Move of raw collection failed: {e}") from e - def get_archive_summary(self, etype: str, before: datetime) -> Optional[dict]: + def get_archive_summary(self, etype: str, before: datetime) -> dict | None: collection_summaries = [] for archive_col_name in self._archive_col_names(etype): result_cursor = self._get_archive_summary(archive_col_name, before=before) @@ -1187,7 +1186,7 @@ def drop_empty_archives(self, etype: str) -> int: raise DatabaseError(f"Drop of empty archive failed: {e}") from e return dropped_count - def get_module_cache(self, override_called_id: Optional[str] = None): + def get_module_cache(self, override_called_id: str | None = None): """Return a persistent cache collection for given module name. Module name is determined automatically, but you can override it. 
diff --git a/dp3/database/magic.py b/dp3/database/magic.py index 4e2a856e..e4f7a17e 100644 --- a/dp3/database/magic.py +++ b/dp3/database/magic.py @@ -58,9 +58,9 @@ """ import re -from datetime import datetime, timezone +from datetime import UTC, datetime from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network -from typing import Any, Union +from typing import Any from bson import Binary @@ -87,9 +87,9 @@ def _binary_id_filter(value: Any, _) -> dict[str, Binary]: magic_type, value = match.groups() value = magic_string_replacements[magic_type](value, True) - if isinstance(value, (IPv4Address, IPv6Address, MACAddress)): + if isinstance(value, IPv4Address | IPv6Address | MACAddress): return _binary_snapshot_bucket_range(value.packed) - if isinstance(value, (IPv4Network, IPv6Network)): + if isinstance(value, IPv4Network | IPv6Network): return { "$gte": _pack_binary_snapshot_bucket_id(value[0].packed, 0), "$lte": _pack_binary_snapshot_bucket_id(value[-1].packed, -1), @@ -100,18 +100,14 @@ def _binary_id_filter(value: Any, _) -> dict[str, Binary]: raise ValueError(f"Unsupported value type {type(value)}: {value}") -def _parse_ipv4_network( - value: str, in_id_filter: bool -) -> Union[IPv4Network, dict[str, IPv4Address]]: +def _parse_ipv4_network(value: str, in_id_filter: bool) -> IPv4Network | dict[str, IPv4Address]: ip = IPv4Network(value) if in_id_filter: return ip return {"$gte": ip[0], "$lte": ip[-1]} -def _parse_ipv6_network( - value: str, in_id_filter: bool -) -> Union[IPv6Network, dict[str, IPv6Address]]: +def _parse_ipv6_network(value: str, in_id_filter: bool) -> IPv6Network | dict[str, IPv6Address]: ip = IPv6Network(value) if in_id_filter: return ip @@ -122,12 +118,12 @@ def _parse_mac_address(value: str, _) -> MACAddress: return MACAddress(value) -def _parse_date_ts(value: Union[int, float], _) -> datetime: - return datetime.fromtimestamp(float(value), timezone.utc) +def _parse_date_ts(value: int | float, _) -> datetime: + return datetime.fromtimestamp(float(value), UTC) def _parse_date_string(value: str, _) -> datetime: - return datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC) magic_string_replacements = { @@ -153,7 +149,7 @@ def search_and_replace(query: dict[str, Any]) -> dict[str, Any]: """ if isinstance(query, dict): for key, value in query.items(): - if isinstance(value, (dict, list)): + if isinstance(value, dict | list): search_and_replace(value) elif isinstance(value, str): match = magic_regex.match(value) diff --git a/dp3/database/schema_cleaner.py b/dp3/database/schema_cleaner.py index 09ada81b..ae30e73a 100644 --- a/dp3/database/schema_cleaner.py +++ b/dp3/database/schema_cleaner.py @@ -1,9 +1,9 @@ import logging import time from collections import defaultdict -from datetime import datetime +from collections.abc import Callable +from datetime import UTC, datetime from logging import Logger -from typing import Callable import pymongo from pymongo import DeleteOne, InsertOne @@ -12,7 +12,6 @@ from dp3.common.attrspec import ID_REGEX, AttrSpecType, AttrType from dp3.common.config import HierarchicalDict, ModelSpec -from dp3.common.types import UTC from dp3.common.utils import batched # number of seconds to wait for the i-th attempt to reconnect after error diff --git a/dp3/database/snapshots.py b/dp3/database/snapshots.py index 1e8a5407..79b1e536 100644 --- a/dp3/database/snapshots.py +++ b/dp3/database/snapshots.py @@ -6,7 +6,7 @@ from collections.abc import 
Iterable from datetime import datetime, timedelta from ipaddress import IPv4Address, IPv6Address -from typing import Any, Optional, Union +from typing import Any import pymongo from bson import Binary @@ -125,11 +125,11 @@ def _binary_bucket_range(self, eid: bytes) -> dict: } @abc.abstractmethod - def _bucket_id(self, eid: AnyEidT, ctime: datetime) -> Union[str, Binary]: + def _bucket_id(self, eid: AnyEidT, ctime: datetime) -> str | Binary: """Returns `_id` for snapshot bucket document.""" @abc.abstractmethod - def _filter_from_bid(self, b_id: Union[bytes, str]) -> dict: + def _filter_from_bid(self, b_id: bytes | str) -> dict: """Returns filter for snapshots with same eid as given bucket document _id. Args: b_id: the _id of the snapshot bucket, type depends on etype's data type @@ -154,8 +154,8 @@ def get_latest_one(self, eid: AnyEidT) -> dict: def find_latest( self, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> Cursor: """Find latest snapshots of given `etype`. @@ -193,8 +193,8 @@ def find_latest( def count_latest( self, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> int: """Count latest snapshots of given `etype`. @@ -251,11 +251,11 @@ def _prepare_latest_query( def get_by_eid( self, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, - ) -> Union[Cursor, CommandCursor]: + ) -> Cursor | CommandCursor: """Get all (or filtered) snapshots of given `eid`. This method is useful for displaying `eid`'s history on web. @@ -360,8 +360,8 @@ def get_distinct_val_count(self, attr: str) -> dict[Any, int]: def _get_oversized( self, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, ) -> Cursor: @@ -789,8 +789,8 @@ def get_latest_one(self, entity_type: str, eid: AnyEidT) -> dict: def find_latest( self, entity_type: str, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> Cursor: """Find latest snapshots of given `etype`. @@ -825,8 +825,8 @@ def find_latest( def count_latest( self, entity_type: str, - fulltext_filters: Optional[dict[str, str]] = None, - generic_filter: Optional[dict[str, Any]] = None, + fulltext_filters: dict[str, str] | None = None, + generic_filter: dict[str, Any] | None = None, ) -> int: """Count latest snapshots of given `etype`. @@ -844,11 +844,11 @@ def get_by_eid( self, entity_type: str, eid: AnyEidT, - t1: Optional[datetime] = None, - t2: Optional[datetime] = None, + t1: datetime | None = None, + t2: datetime | None = None, skip: int = 0, limit: int = 0, - ) -> Union[Cursor, CommandCursor]: + ) -> Cursor | CommandCursor: """Get all (or filtered) snapshots of given `eid`. This method is useful for displaying `eid`'s history on web. 
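Two more 3.10/3.11 idioms recur in the files that follow: `isinstance()` accepting PEP 604 unions in place of tuples (3.10+), and `datetime.UTC` as the built-in alias for `timezone.utc` (3.11+). A small runnable sketch, again illustrative rather than taken from the patch:

```python
# Illustrative sketch only -- not a hunk from this patch.
from datetime import UTC, datetime  # datetime.UTC is the 3.11+ alias for timezone.utc
from ipaddress import IPv4Address, IPv6Address, ip_address

value = ip_address("192.0.2.1")

# On Python 3.10+, isinstance() accepts a union type directly,
# replacing the tuple form isinstance(value, (IPv4Address, IPv6Address)).
if isinstance(value, IPv4Address | IPv6Address):
    ts = datetime.fromtimestamp(0, UTC)  # equivalent to timezone.utc
```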
diff --git a/dp3/history_management/history_manager.py b/dp3/history_management/history_manager.py index 15d78926..97f1f9e9 100644 --- a/dp3/history_management/history_manager.py +++ b/dp3/history_management/history_manager.py @@ -2,11 +2,10 @@ import json import logging import os -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path -from typing import Optional -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict from dp3.common.attrspec import ( AttrSpecObservations, @@ -16,7 +15,7 @@ ) from dp3.common.callback_registrar import CallbackRegistrar from dp3.common.config import CronExpression, PlatformConfig -from dp3.common.types import UTC, DP3Encoder, ParsedTimedelta +from dp3.common.types import DP3Encoder, ParsedTimedelta from dp3.common.utils import entity_expired from dp3.database.database import DatabaseError, EntityDatabase @@ -46,10 +45,10 @@ class DPArchivationConfig(BaseModel): schedule: CronExpression older_than: ParsedTimedelta - archive_dir: Optional[str] = None + archive_dir: str | None = None -class HistoryManagerConfig(BaseModel, extra=Extra.forbid): +class HistoryManagerConfig(BaseModel): """Configuration for history manager. Attributes: @@ -60,6 +59,8 @@ class HistoryManagerConfig(BaseModel, extra=Extra.forbid): datapoint_archivation: Configuration for datapoint archivation. """ + model_config = ConfigDict(extra="forbid") + aggregation_schedule: CronExpression datapoint_cleaning_schedule: CronExpression mark_datapoints_schedule: CronExpression @@ -237,7 +238,7 @@ def _reformat_dp(dp): def _get_raw_dps_summary( self, before: datetime - ) -> tuple[Optional[datetime], Optional[datetime], int]: + ) -> tuple[datetime | None, datetime | None, int]: date_ranges = [] for etype in self.model_spec.entities: summary = self.db.get_archive_summary(etype, before=before) diff --git a/dp3/history_management/telemetry.py b/dp3/history_management/telemetry.py index f4d0f17e..20c84420 100644 --- a/dp3/history_management/telemetry.py +++ b/dp3/history_management/telemetry.py @@ -1,7 +1,7 @@ import logging import threading import time -from datetime import datetime +from datetime import UTC, datetime import requests from pymongo import ASCENDING, UpdateOne @@ -10,7 +10,6 @@ from dp3.common.config import PlatformConfig from dp3.common.datapoint import DataPointObservationsBase, DataPointTimeseriesBase from dp3.common.task import DataPointTask -from dp3.common.types import UTC from dp3.database.database import EntityDatabase @@ -43,7 +42,7 @@ def note_latest_src_timestamp(self, task: DataPointTask): """Note the latest timestamp of each source in the task""" latest_timestamps = {} for dp in task.data_points: - has_timestamp = isinstance(dp, (DataPointObservationsBase, DataPointTimeseriesBase)) + has_timestamp = isinstance(dp, DataPointObservationsBase | DataPointTimeseriesBase) if dp.src is None or not has_timestamp: continue latest_timestamp = dp.t2 or dp.t1 diff --git a/dp3/scripts/add_hashes.py b/dp3/scripts/add_hashes.py index 87571bac..7104e8f4 100755 --- a/dp3/scripts/add_hashes.py +++ b/dp3/scripts/add_hashes.py @@ -29,7 +29,7 @@ model_spec = ModelSpec(config.get("db_entities")) # Connect to database -connection_conf = MongoConfig.parse_obj(config.get("database", {})) +connection_conf = MongoConfig.model_validate(config.get("database", {})) client = EntityDatabase.connect(connection_conf) client.admin.command("ping") diff --git a/dp3/scripts/datapoint_log_converter.py b/dp3/scripts/datapoint_log_converter.py index 
b5086f4b..a9dd1ec4 100755 --- a/dp3/scripts/datapoint_log_converter.py +++ b/dp3/scripts/datapoint_log_converter.py @@ -7,7 +7,8 @@ import logging import os import re -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import pandas as pd from dateutil.parser import parse as parsetime diff --git a/dp3/scripts/dummy_sender.py b/dp3/scripts/dummy_sender.py index b6d8cb2f..d9e9c856 100755 --- a/dp3/scripts/dummy_sender.py +++ b/dp3/scripts/dummy_sender.py @@ -6,7 +6,7 @@ import os import time from argparse import ArgumentParser -from datetime import datetime, timezone +from datetime import UTC, datetime from itertools import islice from queue import Queue from threading import Event, Thread @@ -14,8 +14,6 @@ import pandas as pd import requests -UTC = timezone.utc def get_valid_path(parser, arg): if not os.path.exists(arg): diff --git a/dp3/snapshots/snapshooter.py b/dp3/snapshots/snapshooter.py index 717524ee..826cc2a0 100644 --- a/dp3/snapshots/snapshooter.py +++ b/dp3/snapshots/snapshooter.py @@ -19,8 +19,9 @@ import logging from collections import defaultdict -from datetime import datetime -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from datetime import UTC, datetime +from typing import Any import pymongo.errors from event_count_logger import DummyEventGroup @@ -42,7 +43,7 @@ parse_eids_from_cache, task_context, ) -from dp3.common.types import UTC, EventGroupType +from dp3.common.types import EventGroupType from dp3.common.utils import get_func_name from dp3.database.database import EntityDatabase from dp3.snapshots.snapshot_hooks import ( @@ -69,7 +70,7 @@ def __init__( task_queue_writer: TaskQueueWriter, platform_config: PlatformConfig, scheduler: Scheduler, - elog: Optional[EventGroupType] = None, + elog: EventGroupType | None = None, ) -> None: self.log = logging.getLogger("SnapShooter") @@ -185,7 +186,7 @@ def register_timeseries_hook( def register_correlation_hook( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -543,7 +544,7 @@ def make_snapshot(self, task: Snapshot): self.db.update_metadata(task.time, metadata={"linked_finished": True}, worker_id=0) @staticmethod - def _remove_record_from_value(spec: AttrSpecType, value: Union[dict, list[dict]]): + def _remove_record_from_value(spec: AttrSpecType, value: dict | list[dict]): if spec.is_iterable: for link_dict in value: if "record" in link_dict: @@ -574,7 +575,7 @@ def extend_master_record(etype, master_record, new_tasks: list[DataPointTask]): for datapoint in task.data_points: if datapoint.etype != etype: continue - dp_dict = datapoint.dict(include={"v", "t1", "t2", "c"}) + dp_dict = datapoint.model_dump(include={"v", "t1", "t2", "c"}) if datapoint.attr in master_record: master_record[datapoint.attr].append(dp_dict) else: @@ -647,7 +648,7 @@ def get_linked_entity_ids(self, entity_type: str, current_values: dict) -> set[t @staticmethod def _get_link_entity_ids( - spec: AttrSpecType, link_value: Union[list[dict], dict] + spec: AttrSpecType, link_value: list[dict] | dict ) -> set[tuple[str, str]]: if spec.is_iterable: return {(spec.relation_to, v["eid"]) for v in link_value} @@ -664,7 +665,7 @@ def link_loaded_entities(self, loaded_entities: dict): entity[attr] = [] val_conf = entity[f"{attr}#c"] pruned_conf = [] - for v, conf in zip(val, 
val_conf, strict=False): if self._keep_link(loaded_entities, attr_spec, v): self._link_record(loaded_entities, attr_spec, v) entity[attr].append(v) @@ -682,7 +683,7 @@ def link_loaded_entities(self, loaded_entities: dict): del entity[key] def _keep_link( - self, loaded_entities: dict, attr_spec: AttrSpecType, val: Union[dict, list[dict]] + self, loaded_entities: dict, attr_spec: AttrSpecType, val: dict | list[dict] ) -> bool: if self.config.keep_empty: return True @@ -693,7 +694,7 @@ def _keep_link( return loaded_entities.get((attr_spec.relation_to, val["eid"])) is not None @staticmethod - def _link_record(loaded_entities: dict, attr_spec: AttrSpecType, val: Union[dict, list[dict]]): + def _link_record(loaded_entities: dict, attr_spec: AttrSpecType, val: dict | list[dict]): if attr_spec.is_iterable: for link_dict in val: link_dict["record"] = loaded_entities.get( diff --git a/dp3/snapshots/snapshot_hooks.py b/dp3/snapshots/snapshot_hooks.py index dcf437f5..2ffe3c1e 100644 --- a/dp3/snapshots/snapshot_hooks.py +++ b/dp3/snapshots/snapshot_hooks.py @@ -4,9 +4,8 @@ import logging from collections import defaultdict -from collections.abc import Hashable +from collections.abc import Callable, Hashable from dataclasses import dataclass, field -from typing import Callable, Union from dp3.common.attrspec import AttrType from dp3.common.config import ModelSpec @@ -84,7 +83,7 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT def register( self, - hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], None | list[DataPointTask]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], diff --git a/dp3/task_processing/task_executor.py b/dp3/task_processing/task_executor.py index 5703820d..ff8a487d 100644 --- a/dp3/task_processing/task_executor.py +++ b/dp3/task_processing/task_executor.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable from event_count_logger import DummyEventGroup @@ -193,10 +193,10 @@ def refresh_on_entity_creation( for master_record in self.db.get_worker_master_records( worker_id, worker_cnt, etype, projection=projection ): - with task_context(self.model_spec): + with task_context(self.model_spec, allow_empty_data_point_task=True): task = DataPointTask(etype=etype, eid=master_record["_id"]) - self.log.debug(f"Refreshing {etype}/{task.eid}") - new_tasks += self._task_entity_hooks[task.etype].run_on_creation(task.eid, task) + self.log.debug(f"Refreshing {etype}/{task.eid}") + new_tasks += self._task_entity_hooks[task.etype].run_on_creation(task.eid, task) return new_tasks diff --git a/dp3/task_processing/task_hooks.py b/dp3/task_processing/task_hooks.py index 0cba01f0..527c1b72 100644 --- a/dp3/task_processing/task_hooks.py +++ b/dp3/task_processing/task_hooks.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable from dp3.common.attrspec import AttrType from dp3.common.config import ModelSpec diff --git a/dp3/task_processing/task_queue.py b/dp3/task_processing/task_queue.py index 6ddd1718..f6f2f17a 100644 --- a/dp3/task_processing/task_queue.py +++ b/dp3/task_processing/task_queue.py @@ -36,7 +36,8 @@ import logging import threading import time -from typing import Callable, Union +from collections.abc import Callable +from typing import Literal import amqpstorm @@ -94,7 +95,7 @@ class RobustAMQPConnection: host, port, virtual_host, username, password """ - def __init__(self, 
rabbit_config: dict = None) -> None: + def __init__(self, rabbit_config: dict | None = None) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config self.log = logging.getLogger("RobustAMQPConnection") self.conn_params = { @@ -104,8 +105,8 @@ def __init__(self, rabbit_config: dict = None) -> None: "username": rabbit_config.get("username", "guest"), "password": rabbit_config.get("password", "guest"), } - self.connection: amqpstorm.Connection = None - self.channel: amqpstorm.Channel = None + self.connection: amqpstorm.Connection | None = None + self.channel: amqpstorm.Channel | None = None self._connection_id = 0 def __del__(self): @@ -133,9 +134,10 @@ def connect(self) -> None: # This was a repeated attempt, print success message with ERROR level self.log.error("... it's OK now, we're successfully connected!") - self.channel = self.connection.channel() - self.channel.confirm_deliveries() - self.channel.basic.qos(PREFETCH_COUNT) + channel = self.connection.channel() + channel.confirm_deliveries() + channel.basic.qos(PREFETCH_COUNT) + self.channel = channel break except amqpstorm.AMQPError as e: sleep_time = RECONNECT_DELAYS[min(attempts, len(RECONNECT_DELAYS)) - 1] @@ -152,7 +154,7 @@ def disconnect(self) -> None: self.connection = None self.channel = None - def check_queue_existence(self, queue_name: str) -> bool: + def check_queue_existence(self, queue_name: str | None) -> bool: if queue_name is None: return True assert self.channel is not None, "not connected" @@ -191,10 +193,10 @@ def __init__( self, app_name: str, workers: int = 1, - rabbit_config: dict = None, - exchange: str = None, - priority_exchange: str = None, - parent_logger: logging.Logger = None, + rabbit_config: dict | None = None, + exchange: str | None = None, + priority_exchange: str | None = None, + parent_logger: logging.Logger | None = None, ) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config assert isinstance(workers, int) and workers >= 1, "count of workers must be positive number" @@ -360,10 +362,10 @@ def __init__( parse_task: Callable[[str], Task], app_name: str, worker_index: int = 0, - rabbit_config: dict = None, - queue: str = None, - priority_queue: Union[str, bool] = None, - parent_logger: logging.Logger = None, + rabbit_config: dict | None = None, + queue: str | None = None, + priority_queue: str | Literal[False] | None = None, + parent_logger: logging.Logger | None = None, ) -> None: rabbit_config = {} if rabbit_config is None else rabbit_config assert callable(callback), "callback must be callable object" @@ -391,14 +393,14 @@ def __init__( priority_queue = DEFAULT_PRIORITY_QUEUE.format(app_name, worker_index) elif priority_queue is False: priority_queue = None - self.queue_name = queue - self.priority_queue_name = priority_queue + self.queue_name: str = queue + self.priority_queue_name: str | None = priority_queue self.worker_index = worker_index self.running = False - self._consuming_thread = None - self._processing_thread = None + self._consuming_thread: threading.Thread | None = None + self._processing_thread: threading.Thread | None = None # Receive messages into 2 temporary queues # (max length should be equal to prefetch_count set in RabbitMQReader) @@ -490,11 +492,12 @@ def ack(self, msg_tag: tuple[int, int]) -> bool: Returns: Whether the message was acknowledged successfully and can be processed further. 
""" - conn_id, msg_tag = msg_tag + assert self.channel is not None, "not connected" + conn_id, _delivery_tag = msg_tag if conn_id != self._connection_id: return False try: - self.channel.basic.ack(delivery_tag=msg_tag) + self.channel.basic.ack(delivery_tag=_delivery_tag) except amqpstorm.AMQPChannelError as why: self.log.error("Channel error while acknowledging message: %s", why) self.reconnect() @@ -503,6 +506,7 @@ def ack(self, msg_tag: tuple[int, int]) -> bool: def _consuming_thread_func(self): # Register consumers and start consuming loop, reconnect on error + assert self.channel is not None, "not connected" while self.running: try: # Register consumers on both queues @@ -587,8 +591,8 @@ def watchdog(self): Register to be called periodically by scheduler. """ - proc = self._processing_thread.is_alive() - cons = self._consuming_thread.is_alive() + proc = self._processing_thread is not None and self._processing_thread.is_alive() + cons = self._consuming_thread is not None and self._consuming_thread.is_alive() if not proc or not cons: self.log.error( @@ -599,7 +603,8 @@ def watchdog(self): self._stop_consuming_thread() self._stop_processing_thread() - self.channel.close() + if self.channel is not None: + self.channel.close() self.channel = None self.cache.clear() self.cache_pri.clear() @@ -609,16 +614,17 @@ def watchdog(self): def _stop_consuming_thread(self) -> None: if self._consuming_thread: - if self._consuming_thread.is_alive: + if self._consuming_thread.is_alive(): # if not connected, no problem with contextlib.suppress(amqpstorm.AMQPError): - self.channel.stop_consuming() + if self.channel is not None: + self.channel.stop_consuming() self._consuming_thread.join() self._consuming_thread = None def _stop_processing_thread(self) -> None: if self._processing_thread: - if self._processing_thread.is_alive: + if self._processing_thread.is_alive(): self.running = False # tell processing thread to stop self.cache_full.set() # break potential wait() for data self._processing_thread.join() diff --git a/dp3/template/app/docker/python/Dockerfile b/dp3/template/app/docker/python/Dockerfile index f01df3ba..81c6589f 100644 --- a/dp3/template/app/docker/python/Dockerfile +++ b/dp3/template/app/docker/python/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 # Base interpreter with installed requirements -FROM python:3.9-slim as base +FROM python:3.11-slim as base RUN apt-get update; apt-get install -y git # Install requirements diff --git a/dp3/testing/__init__.py b/dp3/testing/__init__.py new file mode 100644 index 00000000..350a6e9d --- /dev/null +++ b/dp3/testing/__init__.py @@ -0,0 +1,13 @@ +"""Testing helpers for DP3 applications.""" + +from dp3.testing.case import DP3ModuleTestCase +from dp3.testing.config import CONFIG_DIR_ENV, resolve_config_dir +from dp3.testing.registrar import HookRegistration, TestCallbackRegistrar + +__all__ = [ + "CONFIG_DIR_ENV", + "DP3ModuleTestCase", + "HookRegistration", + "TestCallbackRegistrar", + "resolve_config_dir", +] diff --git a/dp3/testing/assertions.py b/dp3/testing/assertions.py new file mode 100644 index 00000000..633a8263 --- /dev/null +++ b/dp3/testing/assertions.py @@ -0,0 +1,140 @@ +"""Assertion helpers for DP3 module tests.""" + +import unittest +from collections.abc import Iterable +from typing import Any + +from pydantic import BaseModel + +from dp3.common.datapoint import DataPointBase +from dp3.common.task import DataPointTask + +_UNSET = object() + + +class ModuleAssertions(unittest.TestCase): + """Partial-match assertions for module 
hook outputs.""" + + def assert_no_tasks(self, tasks: Iterable[DataPointTask]) -> None: + self.assertEqual([], list(tasks)) + + def assert_no_datapoints(self, tasks: Iterable[DataPointTask]) -> None: + self.assertEqual([], list(self.iter_datapoints(tasks))) + + def assert_task_emitted( # noqa: PLR0913 + self, + tasks: Iterable[DataPointTask], + *, + etype: Any = _UNSET, + eid: Any = _UNSET, + data_points: Any = _UNSET, + tags: Any = _UNSET, + ttl_tokens: Any = _UNSET, + delete: Any = _UNSET, + ) -> DataPointTask: + """Assert that a task matching the supplied ``DataPointTask`` fields was emitted.""" + expected = _selected_fields( + etype=etype, + eid=eid, + data_points=data_points, + tags=tags, + ttl_tokens=ttl_tokens, + delete=delete, + ) + task_list = list(tasks) + for task in task_list: + if self._partial_match(dump_value(task), expected): + return task + self.fail(f"No emitted task matched {expected!r}. Emitted tasks: {dump_value(task_list)!r}") + + def assert_datapoint( # noqa: PLR0913 + self, + tasks: Iterable[DataPointTask], + *, + etype: Any = _UNSET, + eid: Any = _UNSET, + attr: Any = _UNSET, + src: Any = _UNSET, + v: Any = _UNSET, + c: Any = _UNSET, + t1: Any = _UNSET, + t2: Any = _UNSET, + ) -> DataPointBase: + """Assert that a datapoint matching the supplied ``DataPointBase`` fields was emitted.""" + expected = _selected_fields( + etype=etype, + eid=eid, + attr=attr, + src=src, + v=v, + c=c, + t1=t1, + t2=t2, + ) + datapoints = list(self.iter_datapoints(tasks)) + for dp in datapoints: + if self._partial_match(dump_value(dp), expected): + return dp + self.fail( + f"No emitted datapoint matched {expected!r}. " + f"Emitted datapoints: {dump_value(datapoints)!r}" + ) + + def assert_record_contains(self, record: dict, **expected) -> None: + if not self._partial_match(record, expected): + self.fail(f"Record {record!r} does not contain expected values {expected!r}") + + def assert_record_attr(self, record: dict, attr: str, expected: Any) -> None: + if attr not in record: + self.fail(f"Record {record!r} does not contain attribute {attr!r}") + if not self._partial_match(record[attr], expected): + self.fail( + f"Record attribute {attr!r} value {record[attr]!r} " + f"does not match expected value {expected!r}" + ) + + def assert_record_unchanged(self, before: dict, after: dict) -> None: + self.assertEqual(dump_value(before), dump_value(after)) + + assertNoTasks = assert_no_tasks + assertNoDatapoints = assert_no_datapoints + assertTaskEmitted = assert_task_emitted + assertDatapoint = assert_datapoint + assertRecordContains = assert_record_contains + assertRecordAttr = assert_record_attr + assertRecordUnchanged = assert_record_unchanged + + @staticmethod + def iter_datapoints(tasks: Iterable[DataPointTask]) -> Iterable[DataPointBase]: + for task in tasks: + yield from task.data_points + + @classmethod + def _partial_match(cls, actual: Any, expected: Any) -> bool: + actual = dump_value(actual) + expected = dump_value(expected) + if isinstance(expected, dict): + if not isinstance(actual, dict): + return False + return all( + key in actual and cls._partial_match(actual[key], value) + for key, value in expected.items() + ) + return actual == expected + + +def _selected_fields(**fields: Any) -> dict[str, Any]: + return {key: value for key, value in fields.items() if value is not _UNSET} + + +def dump_value(value: Any) -> Any: + """Convert pydantic values recursively to plain Python containers.""" + if isinstance(value, BaseModel): + return value.model_dump() + if isinstance(value, list): + return 
[dump_value(item) for item in value] + if isinstance(value, tuple): + return tuple(dump_value(item) for item in value) + if isinstance(value, dict): + return {key: dump_value(item) for key, item in value.items()} + return value diff --git a/dp3/testing/case.py b/dp3/testing/case.py new file mode 100644 index 00000000..0ba85fc6 --- /dev/null +++ b/dp3/testing/case.py @@ -0,0 +1,337 @@ +"""unittest base class for DP3 secondary module tests.""" + +import copy +import unittest +from collections.abc import Callable, Iterable, Mapping, Sequence +from datetime import UTC, datetime +from typing import Any, Generic, TypeVar + +from dp3.common.attrspec import AttrType +from dp3.common.base_module import BaseModule +from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig +from dp3.common.datapoint import DataPointBase +from dp3.common.task import DataPointTask, task_context +from dp3.common.utils import get_func_name +from dp3.testing.assertions import ModuleAssertions +from dp3.testing.config import ( + CONFIG_DIR_ENV, + build_model_spec, + build_platform_config, + get_module_config, + load_config, + resolve_config_dir, +) +from dp3.testing.registrar import HookRegistration, TestCallbackRegistrar + +ModuleT = TypeVar("ModuleT", bound=BaseModule) + + +class DP3ModuleTestCase(ModuleAssertions, unittest.TestCase, Generic[ModuleT]): + """Base class for unit tests of DP3 secondary modules. + + By default the app configuration directory is read from ``DP3_CONFIG_DIR``. Subclasses may set + ``config_dir`` explicitly when they need a fixed fixture config. + """ + + config_dir: str | None = None + config_env_var: str = CONFIG_DIR_ENV + module_class: type[ModuleT] + module_name: str | None = None + module_config: dict | None = None + app_name: str = "test" + process_index: int = 0 + num_processes: int = 1 + module: ModuleT + + def setUp(self) -> None: + super().setUp() + self.config_base_path = self.resolve_config_dir() + self.config = self.load_config() + self.model_spec = self.make_model_spec(self.config) + self.platform_config = self.make_platform_config() + self.registrar = self.make_registrar() + self.module = self.make_module(self.module_class, self.get_module_config(), self.registrar) + + def resolve_config_dir(self) -> str: + return resolve_config_dir(self.config_dir, self.config_env_var) + + def load_config(self) -> HierarchicalDict: + return load_config(self.config_base_path, self.config_env_var) + + def make_model_spec(self, config: HierarchicalDict) -> ModelSpec: + return build_model_spec(config) + + def make_platform_config(self) -> PlatformConfig: + return build_platform_config( + self.config, + self.model_spec, + self.config_base_path, + app_name=self.app_name, + process_index=self.process_index, + num_processes=self.num_processes, + env_var=self.config_env_var, + ) + + def get_module_config(self) -> dict: + if self.module_config is not None: + return copy.deepcopy(self.module_config) + return copy.deepcopy(get_module_config(self.config, self.get_module_name())) + + def get_module_name(self) -> str | None: + if self.module_name is not None: + return self.module_name + return self.module_class.__module__.split(".")[-1] + + def make_registrar(self) -> TestCallbackRegistrar: + return TestCallbackRegistrar(self.model_spec) + + def make_module( + self, + module_class: type[ModuleT], + module_config: dict[str, Any], + registrar: TestCallbackRegistrar, + ) -> ModuleT: + return module_class(self.platform_config, module_config, registrar) + + def make_task( + self, + etype: str, + eid: 
Any, + data_points: list[dict | DataPointBase] | None = None, + tags: list | None = None, + ttl_tokens: dict | None = None, + delete: bool = False, + ) -> DataPointTask: + with task_context(self.model_spec): + return DataPointTask( + etype=etype, + eid=eid, + data_points=data_points or [], + tags=tags or [], + ttl_tokens=ttl_tokens, + delete=delete, + ) + + def make_datapoint( + self, + etype: str, + eid: Any, + attr: str, + v: Any, + src: str = "test", + **fields, + ) -> DataPointBase: + task = self.make_task( + etype, + eid, + [dict({"etype": etype, "eid": eid, "attr": attr, "src": src, "v": v}, **fields)], + ) + return task.data_points[0] + + def make_plain_datapoint( + self, etype: str, eid: Any, attr: str, v: Any, src: str = "test", **fields + ) -> DataPointBase: + return self.make_datapoint(etype, eid, attr, v, src=src, **fields) + + def make_observation_datapoint( + self, + etype: str, + eid: Any, + attr: str, + v: Any, + src: str = "test", + t1: datetime | None = None, + t2: datetime | None = None, + c: float = 1.0, + **fields, + ) -> DataPointBase: + t1 = t1 or datetime.now(UTC) + data = {"t1": t1, "c": c, **fields} + if t2 is not None: + data["t2"] = t2 + return self.make_datapoint(etype, eid, attr, v, src=src, **data) + + def make_timeseries_datapoint( + self, + etype: str, + eid: Any, + attr: str, + v: Mapping[str, Sequence[Any]], + src: str = "test", + t1: datetime | None = None, + t2: datetime | None = None, + **fields, + ) -> DataPointBase: + """Create a validated timeseries datapoint. + + For regular timeseries attributes, ``t2`` is inferred when omitted by using the + attribute's configured ``time_step`` and the number of samples in ``v``: + ``t2 = t1 + len(series) * time_step``. For irregular timeseries, ``t1`` is inferred from + the first ``time`` value when omitted. For irregular-interval timeseries, ``t1`` is inferred + from the first ``time_first`` value when omitted. 
+ """ + attr_spec = self.model_spec.attributes[etype, attr] + if attr_spec.t != AttrType.TIMESERIES: + raise ValueError(f"Attribute {etype}/{attr} is not a timeseries attribute.") + + values = dict(v) + t1 = t1 or self._infer_timeseries_t1(attr_spec, values) or datetime.now(UTC) + if t2 is None and attr_spec.timeseries_type == "regular": + time_step = attr_spec.timeseries_params.time_step + if time_step is None: + raise ValueError(f"Regular timeseries attribute {etype}/{attr} has no time_step.") + t2 = t1 + self._timeseries_length(values) * time_step + + data = {"t1": t1, **fields} + if t2 is not None: + data["t2"] = t2 + return self.make_datapoint(etype, eid, attr, values, src=src, **data) + + @staticmethod + def _infer_timeseries_t1(attr_spec, values: Mapping[str, Sequence[Any]]) -> datetime | None: + if attr_spec.timeseries_type == "irregular" and values.get("time"): + return values["time"][0] + if attr_spec.timeseries_type == "irregular_intervals" and values.get("time_first"): + return values["time_first"][0] + return None + + @staticmethod + def _timeseries_length(values: Mapping[str, Sequence[Any]]) -> int: + try: + return len(next(iter(values.values()))) + except StopIteration as e: + raise ValueError("Timeseries datapoint values cannot be empty.") from e + + def run_task_hooks(self, hook_type: str, task: DataPointTask) -> None: + self.registrar.run_task_hooks(hook_type, task) + + def run_allow_entity_creation( + self, entity: str, eid: Any, task: DataPointTask | None = None + ) -> bool: + task = task or self._make_synthetic_task(entity, eid) + return self.registrar.run_allow_entity_creation(entity, eid, task) + + def run_on_entity_creation( + self, entity: str, eid: Any, task: DataPointTask | None = None + ) -> list[DataPointTask]: + task = task or self._make_synthetic_task(entity, eid) + return self.registrar.run_on_entity_creation(entity, eid, task) + + def _make_synthetic_task(self, etype: str, eid: Any) -> DataPointTask: + with task_context(self.model_spec, allow_empty_data_point_task=True): + return DataPointTask(etype=etype, eid=eid) + + def run_on_new_attr(self, entity: str, attr: str, eid: Any, dp: DataPointBase): + return self.registrar.run_on_new_attr(entity, attr, eid, dp) + + def run_correlation_hooks( + self, + entity_type: str, + record: dict, + master_record: dict | None = None, + ) -> list[DataPointTask]: + return self.registrar.run_correlation_hooks(entity_type, record, master_record) + + def run_periodic_update( + self, entity_type: str, eid: Any, master_record: dict, hook_id: str | None = None + ) -> list[DataPointTask]: + return self.registrar.run_periodic_update(entity_type, eid, master_record, hook_id) + + def run_periodic_eid_update( + self, entity_type: str, eid: Any, hook_id: str | None = None + ) -> list[DataPointTask]: + return self.registrar.run_periodic_eid_update(entity_type, eid, hook_id) + + def run_scheduler_job(self, job: int | str | Callable | HookRegistration): + return self.registrar.run_scheduler_job(job) + + def registered(self, kind: str | None = None, **fields) -> list[HookRegistration]: + """Return registrations matching ``kind`` and the supplied registration fields.""" + return [ + registration + for registration in self.registrar.registrations + if self._registration_matches(registration, kind, fields) + ] + + def assert_registered(self, kind: str, **fields) -> HookRegistration: + """Assert that at least one callback registration matches the supplied fields.""" + matches = self.registered(kind, **fields) + if not matches: + self.fail( + 
f"No registration matched kind={kind!r}, fields={fields!r}. " + f"Registered callbacks: {self.registrar.registrations!r}" + ) + return matches[0] + + def assert_registered_once(self, kind: str, **fields) -> HookRegistration: + """Assert that exactly one callback registration matches the supplied fields.""" + matches = self.registered(kind, **fields) + if len(matches) != 1: + self.fail( + f"Expected one registration matching kind={kind!r}, fields={fields!r}; " + f"found {len(matches)}: {matches!r}" + ) + return matches[0] + + def assert_registered_attrs( + self, + entity: str, + expected_attrs: Iterable[str], + *, + kind: str = "on_new_attr", + exact: bool = True, + ) -> list[HookRegistration]: + """Assert that attribute hook registrations exist for the supplied entity attributes.""" + expected = set(expected_attrs) + matches = self.registered(kind, entity=entity) + actual = {registration.attr for registration in matches if registration.attr is not None} + if exact: + self.assertEqual(expected, actual) + else: + missing = expected - actual + if missing: + self.fail(f"Missing registrations for attributes: {sorted(missing)!r}") + return [registration for registration in matches if registration.attr in expected] + + def assert_scheduler_registered(self, **fields) -> HookRegistration: + """Assert that at least one scheduler callback registration matches the supplied fields.""" + return self.assert_registered("scheduler", **fields) + + assertRegistered = assert_registered + assertRegisteredOnce = assert_registered_once + assertRegisteredAttrs = assert_registered_attrs + assertSchedulerRegistered = assert_scheduler_registered + + def _registration_matches( + self, registration: HookRegistration, kind: str | None, fields: dict[str, Any] + ) -> bool: + if kind is not None and registration.kind != kind: + return False + for key, expected in fields.items(): + found, actual = _registration_field(registration, key) + if not found: + return False + if key in {"hook", "func"} and isinstance(expected, str): + if not _callable_name_matches(actual, expected): + return False + elif not self._partial_match(actual, expected): + return False + return True + + +def _registration_field(registration: HookRegistration, key: str) -> tuple[bool, Any]: + if key == "func": + return True, registration.hook + if hasattr(registration, key): + return True, getattr(registration, key) + if key in registration.extra: + return True, registration.extra[key] + schedule = registration.extra.get("schedule", {}) + if key in schedule: + return True, schedule[key] + return False, None + + +def _callable_name_matches(func: Callable, expected: str) -> bool: + func_name = get_func_name(func) + return func_name == expected or func_name.endswith(f".{expected}") diff --git a/dp3/testing/config.py b/dp3/testing/config.py new file mode 100644 index 00000000..a6b27624 --- /dev/null +++ b/dp3/testing/config.py @@ -0,0 +1,63 @@ +"""Configuration helpers for DP3 module tests.""" + +import os +from contextlib import suppress + +from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig, read_config_dir + +CONFIG_DIR_ENV = "DP3_CONFIG_DIR" + + +def resolve_config_dir(config_dir: str | None = None, env_var: str = CONFIG_DIR_ENV) -> str: + """Return an absolute DP3 config directory path. + + Explicit ``config_dir`` values take precedence. If no explicit path is supplied, the path is + read from ``env_var``. 
+ """ + resolved = config_dir or os.environ.get(env_var) + if not resolved: + raise ValueError( + f"DP3 module tests require a config directory. Set {env_var} or pass " + "config_dir explicitly." + ) + return os.path.abspath(resolved) + + +def load_config(config_dir: str | None = None, env_var: str = CONFIG_DIR_ENV) -> HierarchicalDict: + """Load a DP3 config directory for module tests.""" + return read_config_dir(resolve_config_dir(config_dir, env_var), recursive=True) + + +def build_model_spec(config: HierarchicalDict) -> ModelSpec: + """Build a model specification from loaded DP3 configuration.""" + return ModelSpec(config.get("db_entities")) + + +def build_platform_config( + config: HierarchicalDict, + model_spec: ModelSpec, + config_dir: str | None = None, + *, + app_name: str = "test", + process_index: int = 0, + num_processes: int = 1, + env_var: str = CONFIG_DIR_ENV, +) -> PlatformConfig: + """Build the minimal platform config needed by secondary module unit tests.""" + with suppress(Exception): + num_processes = config.get("processing_core.worker_processes", num_processes) + return PlatformConfig( + app_name=app_name, + config_base_path=resolve_config_dir(config_dir, env_var), + config=config, + model_spec=model_spec, + process_index=process_index, + num_processes=num_processes, + ) + + +def get_module_config(config: HierarchicalDict, module_name: str | None) -> dict: + """Return module-specific config from loaded app config.""" + if module_name is None: + return {} + return config.get(f"modules.{module_name}", {}) diff --git a/dp3/testing/registrar.py b/dp3/testing/registrar.py new file mode 100644 index 00000000..e8df2417 --- /dev/null +++ b/dp3/testing/registrar.py @@ -0,0 +1,484 @@ +"""Test callback registrar for DP3 secondary modules.""" + +import copy +import logging +import warnings +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass, field +from functools import wraps +from typing import Any + +from event_count_logger import DummyEventGroup + +from dp3.common.attrspec import AttrType +from dp3.common.config import ModelSpec +from dp3.common.datapoint import DataPointBase +from dp3.common.task import DataPointTask, task_context +from dp3.common.utils import get_func_name +from dp3.snapshots.snapshot_hooks import ( + SnapshotCorrelationHookContainer, + SnapshotTimeseriesHookContainer, +) + + +@dataclass +class HookRegistration: + """Captured callback registration made by a secondary module.""" + + kind: str + hook: Callable + entity: str | None = None + attr: str | None = None + hook_type: str | None = None + hook_id: str | None = None + entity_type: str | None = None + attr_type: str | None = None + depends_on: list[list[str]] = field(default_factory=list) + may_change: list[list[str]] = field(default_factory=list) + refresh: Any = None + period: Any = None + deprecated: bool = False + extra: dict[str, Any] = field(default_factory=dict) + + +def _drop_master_for_test(hook: Callable[[str, dict], Any]) -> Callable[[str, dict, dict], Any]: + @wraps(hook) + def wrapped(entity_type: str, record: dict, _master_record: dict): + return hook(entity_type, record) + + return wrapped + + +class TestCallbackRegistrar: + """Callback registrar implementation for module unit tests.""" + + attr_spec_t_to_on_attr = { + AttrType.PLAIN: "on_new_plain", + AttrType.OBSERVATIONS: "on_new_observation", + AttrType.TIMESERIES: "on_new_ts_chunk", + } + + def __init__(self, model_spec: ModelSpec, log: logging.Logger | None = None): + 
self.model_spec = model_spec + self.log = log or logging.getLogger(self.__class__.__name__) + self.registrations: list[HookRegistration] = [] + self._task_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list) + self._allow_creation_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list) + self._on_creation_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list) + self._attr_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = defaultdict(list) + self._snapshot_init_hooks: list[HookRegistration] = [] + self._snapshot_finalize_hooks: list[HookRegistration] = [] + self._periodic_record_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = ( + defaultdict(list) + ) + self._periodic_eid_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = ( + defaultdict(list) + ) + self._scheduler_jobs: list[HookRegistration] = [] + + self._correlation_hooks = SnapshotCorrelationHookContainer( + self.log, model_spec, DummyEventGroup() + ) + self._timeseries_hooks = SnapshotTimeseriesHookContainer( + self.log, model_spec, DummyEventGroup() + ) + + def scheduler_register( # noqa: PLR0913 + self, + func: Callable, + *, + func_args: list | tuple = None, + func_kwargs: dict = None, + year: int | str = None, + month: int | str = None, + day: int | str = None, + week: int | str = None, + day_of_week: int | str = None, + hour: int | str = None, + minute: int | str = None, + second: int | str = None, + timezone: str = "UTC", + misfire_grace_time: int = 1, + ) -> int: + schedule = { + "year": year, + "month": month, + "day": day, + "week": week, + "day_of_week": day_of_week, + "hour": hour, + "minute": minute, + "second": second, + "timezone": timezone, + "misfire_grace_time": misfire_grace_time, + } + reg = HookRegistration( + kind="scheduler", + hook=func, + extra={ + "func_args": list(func_args or []), + "func_kwargs": dict(func_kwargs or {}), + "schedule": schedule, + }, + ) + self._record(reg) + self._scheduler_jobs.append(reg) + return len(self._scheduler_jobs) - 1 + + def register_task_hook(self, hook_type: str, hook: Callable): + if hook_type != "on_task_start": + raise ValueError(f"Hook type '{hook_type}' doesn't exist.") + reg = HookRegistration(kind="task", hook_type=hook_type, hook=hook) + self._record(reg) + self._task_hooks[hook_type].append(reg) + + def register_allow_entity_creation_hook(self, hook: Callable, entity: str): + self._validate_entity(entity) + reg = HookRegistration(kind="allow_entity_creation", entity=entity, hook=hook) + self._record(reg) + self._allow_creation_hooks[entity].append(reg) + + def register_on_entity_creation_hook( + self, + hook: Callable, + entity: str, + refresh: Any = None, + may_change: list[list[str]] = None, + ): + self._validate_entity(entity) + if refresh is not None and may_change is None: + raise ValueError("'may_change' must be specified if 'refresh' is specified") + reg = HookRegistration( + kind="on_entity_creation", + entity=entity, + hook=hook, + refresh=refresh, + may_change=copy.deepcopy(may_change or []), + ) + self._record(reg) + self._on_creation_hooks[entity].append(reg) + + def register_on_new_attr_hook( + self, + hook: Callable, + entity: str, + attr: str, + refresh: Any = None, + may_change: list[list[str]] = None, + ): + hook_type = self._hook_type_for_attr(entity, attr) + if refresh is not None and may_change is None: + raise ValueError("'may_change' must be specified if 'refresh' is specified") + reg = HookRegistration( + kind="on_new_attr", + hook_type=hook_type, + entity=entity, + 
attr=attr, + hook=hook, + refresh=refresh, + may_change=copy.deepcopy(may_change or []), + ) + self._record(reg) + self._attr_hooks[entity, attr].append(reg) + + def register_entity_hook(self, hook_type: str, hook: Callable, entity: str): + warnings.warn( + "register_entity_hook() is deprecated; use " + "register_allow_entity_creation_hook() or register_on_entity_creation_hook().", + DeprecationWarning, + stacklevel=2, + ) + self._validate_entity(entity) + if hook_type == "allow_entity_creation": + reg = HookRegistration( + kind="allow_entity_creation", entity=entity, hook=hook, deprecated=True + ) + self._record(reg) + self._allow_creation_hooks[entity].append(reg) + return + if hook_type == "on_entity_creation": + reg = HookRegistration( + kind="on_entity_creation", entity=entity, hook=hook, deprecated=True + ) + self._record(reg) + self._on_creation_hooks[entity].append(reg) + return + raise ValueError(f"Hook type '{hook_type}' doesn't exist.") + + def register_attr_hook(self, hook_type: str, hook: Callable, entity: str, attr: str): + warnings.warn( + "register_attr_hook() is deprecated; use register_on_new_attr_hook().", + DeprecationWarning, + stacklevel=2, + ) + expected_hook_type = self._hook_type_for_attr(entity, attr) + if hook_type != expected_hook_type: + raise ValueError(f"Hook type '{hook_type}' doesn't exist for {entity}/{attr}.") + reg = HookRegistration( + kind="on_new_attr", + hook_type=hook_type, + entity=entity, + attr=attr, + hook=hook, + deprecated=True, + ) + self._record(reg) + self._attr_hooks[entity, attr].append(reg) + + def register_timeseries_hook(self, hook: Callable, entity_type: str, attr_type: str): + self._timeseries_hooks.register(hook, entity_type, attr_type) + reg = HookRegistration( + kind="timeseries", hook=hook, entity_type=entity_type, attr_type=attr_type + ) + self._record(reg) + + def register_correlation_hook( + self, + hook: Callable, + entity_type: str, + depends_on: list[list[str]], + may_change: list[list[str]], + ): + wrapped = _drop_master_for_test(hook) + hook_id = self._correlation_hooks.register(wrapped, entity_type, depends_on, may_change) + reg = HookRegistration( + kind="correlation", + hook=hook, + hook_id=hook_id, + entity_type=entity_type, + depends_on=copy.deepcopy(depends_on), + may_change=copy.deepcopy(may_change), + ) + self._record(reg) + + def register_correlation_hook_with_master_record( + self, + hook: Callable, + entity_type: str, + depends_on: list[list[str]], + may_change: list[list[str]], + ): + hook_id = self._correlation_hooks.register(hook, entity_type, depends_on, may_change) + reg = HookRegistration( + kind="correlation_with_master_record", + hook=hook, + hook_id=hook_id, + entity_type=entity_type, + depends_on=copy.deepcopy(depends_on), + may_change=copy.deepcopy(may_change), + ) + self._record(reg) + + def register_snapshot_init_hook(self, hook: Callable): + reg = HookRegistration(kind="snapshot_init", hook=hook) + self._record(reg) + self._snapshot_init_hooks.append(reg) + + def register_snapshot_finalize_hook(self, hook: Callable): + reg = HookRegistration(kind="snapshot_finalize", hook=hook) + self._record(reg) + self._snapshot_finalize_hooks.append(reg) + + def register_periodic_update_hook( + self, hook: Callable, hook_id: str, entity_type: str, period: Any + ): + self._validate_entity(entity_type) + reg = HookRegistration( + kind="periodic_update", + hook=hook, + hook_id=hook_id, + entity_type=entity_type, + period=period, + ) + self._record(reg) + self._periodic_record_hooks[entity_type, 
hook_id].append(reg) + + def register_periodic_eid_update_hook( + self, hook: Callable, hook_id: str, entity_type: str, period: Any + ): + self._validate_entity(entity_type) + reg = HookRegistration( + kind="periodic_eid_update", + hook=hook, + hook_id=hook_id, + entity_type=entity_type, + period=period, + ) + self._record(reg) + self._periodic_eid_hooks[entity_type, hook_id].append(reg) + + def run_task_hooks(self, hook_type: str, task: DataPointTask) -> None: + for reg in self._task_hooks[hook_type]: + reg.hook(task) + + def run_allow_entity_creation(self, entity: str, eid: Any, task: DataPointTask) -> bool: + return all(reg.hook(eid, task) for reg in self._allow_creation_hooks[entity]) + + def run_on_entity_creation( + self, entity: str, eid: Any, task: DataPointTask + ) -> list[DataPointTask]: + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for reg in self._on_creation_hooks[entity]: + hook_tasks = reg.hook(eid, task) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def run_on_new_attr(self, entity: str, attr: str, eid: Any, dp: DataPointBase): + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for reg in self._attr_hooks[entity, attr]: + hook_tasks = reg.hook(eid, dp) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def run_timeseries_hook( + self, entity_type: str, attr_type: str, attr_history: list[dict] + ) -> list[DataPointTask]: + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for hook in self._timeseries_hooks._hooks[entity_type, attr_type]: + hook_tasks = hook(entity_type, attr_type, attr_history) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def run_correlation_hooks( + self, + entity_type: str, + record: dict, + master_record: dict | None = None, + ) -> list[DataPointTask]: + eid = self._assert_record_eid(record) + return self.run_correlation_hooks_for_entities( + {(entity_type, eid): record}, {(entity_type, eid): master_record or {}} + ) + + def run_correlation_hooks_for_entities( + self, entities: dict[tuple[str, Any], dict], master_records: dict | None = None + ) -> list[DataPointTask]: + master_records = master_records or {} + for entity_type, _ in entities: + self._validate_entity(entity_type) + for record in entities.values(): + self._assert_record_eid(record) + + entity_types = {etype for etype, _ in entities} + hook_subset = [ + (hook_id, hook, etype) + for etype in entity_types + for hook_id, hook in self._correlation_hooks._hooks[etype] + ] + topological_order = self._correlation_hooks._dependency_graph.topological_order + hook_subset.sort(key=lambda item: topological_order.index(item[0])) + + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for _, hook, etype in hook_subset: + for entity_key, record in entities.items(): + if entity_key[0] != etype: + continue + hook_tasks = hook(etype, record, master_records.get(entity_key, {})) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def run_snapshot_init_hooks(self) -> list[DataPointTask]: + return self._run_no_arg_hooks(self._snapshot_init_hooks) + + def run_snapshot_finalize_hooks(self) -> list[DataPointTask]: + return self._run_no_arg_hooks(self._snapshot_finalize_hooks) + + def run_periodic_update( + self, + entity_type: str, + eid: Any, + master_record: dict, + hook_id: str | None = None, + ) -> list[DataPointTask]: + hooks = self._matching_update_hooks(self._periodic_record_hooks, 
entity_type, hook_id) + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for reg in hooks: + hook_tasks = reg.hook(entity_type, eid, master_record) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def run_periodic_eid_update( + self, entity_type: str, eid: Any, hook_id: str | None = None + ) -> list[DataPointTask]: + hooks = self._matching_update_hooks(self._periodic_eid_hooks, entity_type, hook_id) + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for reg in hooks: + hook_tasks = reg.hook(entity_type, eid) + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + def get_scheduler_job(self, job: int | str | Callable | HookRegistration) -> HookRegistration: + """Return a registered scheduler job by index, callable, or callable name.""" + if isinstance(job, int): + return self._scheduler_jobs[job] + if isinstance(job, HookRegistration): + if job.kind != "scheduler": + raise ValueError(f"Registration kind '{job.kind}' is not a scheduler job.") + return job + + matches = [reg for reg in self._scheduler_jobs if _callable_matches(reg.hook, job)] + if not matches: + raise ValueError(f"No scheduler job matches {job!r}.") + if len(matches) > 1: + raise ValueError(f"Multiple scheduler jobs match {job!r}.") + return matches[0] + + def run_scheduler_job(self, job: int | str | Callable | HookRegistration): + reg = self.get_scheduler_job(job) + return reg.hook(*reg.extra["func_args"], **reg.extra["func_kwargs"]) + + def _record(self, registration: HookRegistration) -> None: + self.registrations.append(registration) + + def _validate_entity(self, entity: str) -> None: + if entity not in self.model_spec.entities: + raise ValueError(f"Entity '{entity}' does not exist.") + + def _hook_type_for_attr(self, entity: str, attr: str) -> str: + try: + return self.attr_spec_t_to_on_attr[self.model_spec.attributes[entity, attr].t] + except KeyError as e: + raise ValueError( + f"Cannot register hook for attribute {entity}/{attr}, are you sure it exists?" 
+ ) from e + + @staticmethod + def _assert_record_eid(record: dict) -> Any: + if "eid" not in record: + raise ValueError("Correlation hook records must contain an 'eid' field.") + return record["eid"] + + def _run_no_arg_hooks(self, hooks: list[HookRegistration]) -> list[DataPointTask]: + tasks: list[DataPointTask] = [] + with task_context(self.model_spec): + for reg in hooks: + hook_tasks = reg.hook() + if isinstance(hook_tasks, list): + tasks.extend(hook_tasks) + return tasks + + @staticmethod + def _matching_update_hooks(hooks: dict, entity_type: str, hook_id: str | None): + if hook_id is not None: + return list(hooks[entity_type, hook_id]) + return [reg for (etype, _), regs in hooks.items() if etype == entity_type for reg in regs] + + +def _callable_matches(func: Callable, expected: str | Callable) -> bool: + if callable(expected): + return func == expected + func_name = get_func_name(func) + return func_name == expected or func_name.endswith(f".{expected}") diff --git a/mkdocs.yml b/mkdocs.yml index 44ae7b21..32695963 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,7 @@ nav: - Add an attribute: howto/add-attribute.md - Add an input module: howto/add-input.md - Add a secondary module: howto/add-module.md + - Test a secondary module: howto/test-module.md - Deploy an app: howto/deploy-app.md - Develop DP3 itself: howto/develop-dp3.md - Extend Docs: howto/extending.md diff --git a/pyproject.toml b/pyproject.toml index 0cd5d32c..f7a78460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Intended Audience :: Developers", ] -requires-python = ">=3.9" +requires-python = ">=3.11" dynamic = ["version", "dependencies", "optional-dependencies"] [tool.setuptools_scm] @@ -64,13 +64,13 @@ scripts = { file = ["requirements.scripts.txt"] } ### Black Formatting ################################################################### [tool.black] -target-version = ["py39"] +target-version = ["py311"] line-length = 100 extend-exclude = "/(install|docker)/" ### Ruff Code Linting ################################################################## [tool.ruff] -target-version = "py39" +target-version = "py311" extend-exclude = ["install", "docker"] line-length = 100 show-fixes = true diff --git a/requirements.scripts.txt b/requirements.scripts.txt index 74983ca4..3f4736c5 100644 --- a/requirements.scripts.txt +++ b/requirements.scripts.txt @@ -1,2 +1,2 @@ numpy>=1.23.0 -pandas~=1.4.3 +pandas~=2.2 diff --git a/requirements.txt b/requirements.txt index 2809ff2d..faa473d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -AMQPStorm~=2.7.2 -apscheduler~=3.10.0 -argcomplete~=3.6.0 +AMQPStorm~=2.7 +apscheduler~=3.10 +argcomplete~=3.6 event-count-logger>=1.1 fastapi>=0.109.1 pydantic>=2.4.0 -pymongo~=4.6.3 +pymongo~=4.6 python-dateutil~=2.8 pyyaml~=6.0 -requests~=2.32.0 +requests~=2.32 uvicorn>=0.22.0 diff --git a/tests/test_api/common.py b/tests/test_api/common.py index 9f766fb0..df1b9302 100644 --- a/tests/test_api/common.py +++ b/tests/test_api/common.py @@ -3,7 +3,8 @@ import sys import time import unittest -from typing import Callable, TypeVar +from collections.abc import Callable +from typing import TypeVar import requests from pydantic import BaseModel diff --git a/tests/test_api/test_01_datapoints.py b/tests/test_api/test_01_datapoints.py index c7e73c36..4e92ddc5 100644 --- a/tests/test_api/test_01_datapoints.py +++ b/tests/test_api/test_01_datapoints.py @@ -1,13 +1,11 @@ import json import 
sys -from datetime import datetime +from datetime import UTC, datetime from typing import Any import common from common import ACCEPTED_ERROR_CODES -from dp3.common.types import UTC - class PushDatapoints(common.APITest): def test_invalid_payload(self): diff --git a/tests/test_api/test_get_entity_eid_data.py b/tests/test_api/test_get_entity_eid_data.py index b7e1ac9e..818273b8 100644 --- a/tests/test_api/test_get_entity_eid_data.py +++ b/tests/test_api/test_get_entity_eid_data.py @@ -1,11 +1,10 @@ import sys -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta import common from pydantic import RootModel from dp3.api.internal.entity_response_models import EntityEidData, EntityEidMasterRecord -from dp3.common.types import UTC DATAPOINT_COUNT = 6 diff --git a/tests/test_api/test_raw.py b/tests/test_api/test_raw.py index 1b9ca919..2c9a5170 100644 --- a/tests/test_api/test_raw.py +++ b/tests/test_api/test_raw.py @@ -1,11 +1,11 @@ import datetime import json import sys +from datetime import UTC import common from dp3.api.internal.entity_response_models import EntityRawDataPage -from dp3.common.types import UTC class RawDatapointsIntegration(common.APITest): diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py index d8fd82e4..f31a38b3 100644 --- a/tests/test_api/test_snapshots.py +++ b/tests/test_api/test_snapshots.py @@ -1,12 +1,12 @@ import datetime import json import sys +from datetime import UTC from time import sleep import common from dp3.api.internal.entity_response_models import EntityEidData, EntityEidSnapshots -from dp3.common.types import UTC class SnapshotIntegration(common.APITest): diff --git a/tests/test_api/test_telemetry.py b/tests/test_api/test_telemetry.py index 171079b6..71e1c480 100644 --- a/tests/test_api/test_telemetry.py +++ b/tests/test_api/test_telemetry.py @@ -1,12 +1,11 @@ import datetime import json import sys +from datetime import UTC from time import sleep import common -from dp3.common.types import UTC - class TelemetryEndpoints(common.APITest): @classmethod diff --git a/tests/test_common/test_magic.py b/tests/test_common/test_magic.py index 70f658d8..6bf12925 100644 --- a/tests/test_common/test_magic.py +++ b/tests/test_common/test_magic.py @@ -1,7 +1,7 @@ """Test the search & replace functionality for snapshot generic filter endpoint""" import unittest -from datetime import datetime, timezone +from datetime import UTC, datetime from ipaddress import IPv4Address, IPv6Address from bson import Binary @@ -34,17 +34,17 @@ def test_replace_int(self): def test_replace_date(self): query = {"date_attr": "$$Date{2021-01-01T00:00:00Z}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) def test_replace_date_ts(self): query = {"date_attr": "$$DateTs{1609459200}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) # Test with float value query = {"date_attr": "$$DateTs{1609459200.5}"} - expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, 500000, tzinfo=timezone.utc)} + expected = {"date_attr": datetime(2021, 1, 1, 0, 0, 0, 500000, tzinfo=UTC)} self.assertEqual(search_and_replace(query), expected) def test_replace_ipv4_prefix(self): diff --git a/tests/test_common/test_module_testing.py 
new file mode 100644
index 00000000..e90652b0
--- /dev/null
+++ b/tests/test_common/test_module_testing.py
@@ -0,0 +1,227 @@
+import copy
+import os
+import warnings
+from datetime import UTC, datetime, timedelta
+
+from pydantic import ValidationError
+
+from dp3.common.base_module import BaseModule
+from dp3.common.task import DataPointTask, task_context
+from dp3.testing import DP3ModuleTestCase
+
+
+class SampleModule(BaseModule):
+    def __init__(self, platform_config, module_config, registrar):
+        super().__init__(platform_config, module_config, registrar)
+        registrar.register_allow_entity_creation_hook(self.allow_create, "test_entity_type")
+        registrar.register_on_entity_creation_hook(self.on_create, "test_entity_type")
+        registrar.register_on_new_attr_hook(self.on_string, "test_entity_type", "test_attr_string")
+        registrar.register_correlation_hook(
+            self.copy_string,
+            "test_entity_type",
+            depends_on=[["test_attr_string"]],
+            may_change=[["test_attr_int"]],
+        )
+        registrar.register_correlation_hook_with_master_record(
+            self.copy_master_float,
+            "test_entity_type",
+            depends_on=[],
+            may_change=[["test_attr_float"]],
+        )
+        registrar.register_periodic_update_hook(
+            self.periodic_record, "sample", "test_entity_type", "1d"
+        )
+        registrar.scheduler_register(self.reload, minute="*/5")
+
+    def load_config(self, _config, module_config):
+        self.prefix = module_config.get("prefix", "created")
+        self.reload_count = 0
+
+    def allow_create(self, eid, _task):
+        return eid != "deny"
+
+    def on_create(self, eid, _task):
+        return [
+            DataPointTask(
+                etype="test_entity_type",
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": "test_entity_type",
+                        "eid": eid,
+                        "attr": "test_attr_string",
+                        "src": "secondary/sample",
+                        "v": f"{self.prefix}:{eid}",
+                    }
+                ],
+            )
+        ]
+
+    def on_string(self, eid, dp):
+        return [
+            DataPointTask(
+                etype="test_entity_type",
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": "test_entity_type",
+                        "eid": eid,
+                        "attr": "test_attr_int",
+                        "src": "secondary/sample",
+                        "v": len(dp.v),
+                    }
+                ],
+            )
+        ]
+
+    def copy_string(self, _entity_type, record):
+        record["test_attr_int"] = len(record["test_attr_string"])
+
+    def copy_master_float(self, _entity_type, record, master_record):
+        record["test_attr_float"] = master_record["test_attr_float"]["v"]
+
+    def periodic_record(self, entity_type, eid, master_record):
+        return [
+            DataPointTask(
+                etype=entity_type,
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": entity_type,
+                        "eid": eid,
+                        "attr": "test_attr_string",
+                        "src": "secondary/sample",
+                        "v": master_record["test_attr_string"]["v"],
+                    }
+                ],
+            )
+        ]
+
+    def reload(self):
+        self.reload_count += 1
+
+
+class DeprecatedHookModule(BaseModule):
+    def __init__(self, platform_config, module_config, registrar):
+        super().__init__(platform_config, module_config, registrar)
+        registrar.register_entity_hook("on_entity_creation", self.on_create, "test_entity_type")
+        registrar.register_attr_hook(
+            "on_new_plain", self.on_string, "test_entity_type", "test_attr_string"
+        )
+
+    def on_create(self, _eid, _task):
+        return []
+
+    def on_string(self, _eid, _dp):
+        return []
+
+
+TEST_CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "test_config")
+
+
+class TestDP3ModuleTestCase(DP3ModuleTestCase):
+    config_dir = TEST_CONFIG_DIR
+    module_class = SampleModule
+    module_config = {"prefix": "hello"}
+
+    def test_entity_creation_hook_and_partial_datapoint_assertion(self):
+        self.assertTrue(self.run_allow_entity_creation("test_entity_type", "ok"))
+        self.assertFalse(self.run_allow_entity_creation("test_entity_type", "deny"))
+
+        tasks = self.run_on_entity_creation("test_entity_type", "ok")
+
+        self.assertDatapoint(
+            tasks,
+            etype="test_entity_type",
+            eid="ok",
+            attr="test_attr_string",
+            v="hello:ok",
+        )
+
+    def test_attribute_hook(self):
+        dp = self.make_plain_datapoint("test_entity_type", "e1", "test_attr_string", "abcd")
+
+        tasks = self.run_on_new_attr("test_entity_type", "test_attr_string", "e1", dp)
+
+        self.assertDatapoint(tasks, attr="test_attr_int", v=4)
+
+    def test_correlation_and_master_record_hooks(self):
+        record = {"eid": "e1", "test_attr_string": "abcdef"}
+        master_record = {"test_attr_float": {"v": 1.5}}
+
+        tasks = self.run_correlation_hooks("test_entity_type", record, master_record)
+
+        self.assertNoTasks(tasks)
+        self.assertRecordContains(record, test_attr_int=6, test_attr_float=1.5)
+        self.assertRecordAttr(record, "test_attr_int", 6)
+
+    def test_record_unchanged_assertion(self):
+        record = {"eid": "e1", "labels": ["a", "b"]}
+        unchanged = copy.deepcopy(record)
+
+        self.assertRecordUnchanged(record, unchanged)
+
+    def test_empty_datapoint_task_is_rejected(self):
+        with self.assertRaises(ValidationError):
+            self.make_task("test_entity_type", "e1")
+
+    def test_empty_datapoint_task_is_allowed_for_internal_synthetic_context(self):
+        with task_context(self.model_spec, allow_empty_data_point_task=True):
+            task = DataPointTask(etype="test_entity_type", eid="e1")
+
+        self.assertEqual("test_entity_type", task.etype)
+        self.assertEqual("e1", task.eid)
+        self.assertEqual([], task.data_points)
+
+    def test_no_datapoints_assertion_accepts_tasks_without_datapoints(self):
+        tasks = [self.make_task("test_entity_type", "e1", delete=True)]
+
+        self.assertNoDatapoints(tasks)
+
+    def test_timeseries_datapoint_helper_infers_t2_for_regular_timeseries(self):
+        t1 = datetime(2024, 1, 1, tzinfo=UTC)
+
+        dp = self.make_timeseries_datapoint(
+            "test_entity_type",
+            "e1",
+            "test_attr_timeseries",
+            {"value": [1, 2, 3]},
+            t1=t1,
+        )
+
+        self.assertEqual(t1, dp.t1)
+        self.assertEqual(t1 + timedelta(minutes=30), dp.t2)
+        self.assertEqual([1, 2, 3], dp.v.value)
+
+    def test_periodic_update_hook(self):
+        tasks = self.run_periodic_update(
+            "test_entity_type",
+            "e1",
+            {"test_attr_string": {"v": "periodic"}},
+            hook_id="sample",
+        )
+
+        self.assertDatapoint(tasks, attr="test_attr_string", v="periodic")
+
+    def test_registration_assertions(self):
+        self.assert_registered("on_new_attr", entity="test_entity_type", attr="test_attr_string")
+        self.assert_registered_once("correlation", entity_type="test_entity_type")
+        self.assert_registered_attrs("test_entity_type", ["test_attr_string"])
+
+    def test_scheduler_helpers(self):
+        self.assert_scheduler_registered(func="reload", minute="*/5")
+
+        self.run_scheduler_job("reload")
+
+        self.assertEqual(1, self.module.reload_count)
+
+    def test_deprecated_registrar_methods_are_supported_with_warnings(self):
+        registrar = self.make_registrar()
+        with warnings.catch_warnings(record=True) as captured:
+            warnings.simplefilter("always", DeprecationWarning)
+            self.make_module(DeprecatedHookModule, {}, registrar)
+
+        self.assertEqual(2, len(captured))
+        self.assertTrue(all(item.category is DeprecationWarning for item in captured))
+        self.assertEqual(2, len(registrar.registrations))
+        self.assertTrue(all(reg.deprecated for reg in registrar.registrations))
diff --git a/tests/test_common/test_snapshots.py b/tests/test_common/test_snapshots.py
index ae39eed8..89590f27 100644
--- a/tests/test_common/test_snapshots.py
+++ b/tests/test_common/test_snapshots.py
@@ -3,14 +3,14 @@
 import logging
 import os
 import unittest
+from collections.abc import Callable
+from datetime import UTC
 from functools import partial, update_wrapper
-from typing import Callable, Optional
 
 from event_count_logger import DummyEventGroup
 
 from dp3.common.config import ModelSpec, PlatformConfig, read_config_dir
 from dp3.common.task import Task
-from dp3.common.types import UTC
 from dp3.snapshots.snapshooter import SnapShooter
 from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer
 
@@ -115,7 +115,7 @@ def register_on_entity_delete(
         self, f_one: Callable[[str, str], None], f_many: Callable[[str, list[str]], None]
     ): ...
 
-    def get_module_cache(self, override_called_id: Optional[str] = None):
+    def get_module_cache(self, override_called_id: str | None = None):
         return self.module_cache
 
     def save_snapshot(self, etype: str, snapshot: dict, time: datetime):
diff --git a/tests/test_common/test_types.py b/tests/test_common/test_types.py
index 1c0e0fb0..3235293a 100644
--- a/tests/test_common/test_types.py
+++ b/tests/test_common/test_types.py
@@ -1,5 +1,5 @@
 import unittest
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta, timezone
 
 from pydantic import BaseModel, Field
 
@@ -18,7 +18,7 @@ class _T2Model(BaseModel):
 class TestAwareDatetime(unittest.TestCase):
     def test_naive_datetime_defaults_to_utc(self):
         model = _AwareModel(dt="2024-01-01T10:00:00")
-        self.assertEqual(model.dt.tzinfo, timezone.utc)
+        self.assertEqual(model.dt.tzinfo, UTC)
 
     def test_existing_timezone_is_preserved(self):
         cest_timezone = timezone(timedelta(hours=2), "CEST")
@@ -29,6 +29,6 @@ def test_existing_timezone_is_preserved(self):
     def test_t2_datetime_inherits_timezone_when_missing(self):
         model = _T2Model(t1="2024-01-01T00:00:00")
         self.assertIsNotNone(model.t2)
-        self.assertEqual(model.t1.tzinfo, timezone.utc)
-        self.assertEqual(model.t2.tzinfo, timezone.utc)
+        self.assertEqual(model.t1.tzinfo, UTC)
+        self.assertEqual(model.t2.tzinfo, UTC)
         self.assertEqual(model.t2, model.t1)
diff --git a/tests/test_config/db_entities/test_entity_type.yml b/tests/test_config/db_entities/test_entity_type.yml
index 0caeef40..7ca7f98e 100644
--- a/tests/test_config/db_entities/test_entity_type.yml
+++ b/tests/test_config/db_entities/test_entity_type.yml
@@ -81,6 +81,16 @@ attribs:
       aggregate: true
       post_validity: 1h
       pre_validity: 1h
+  test_attr_timeseries:
+    name: test_attr_timeseries
+    type: timeseries
+    timeseries_type: regular
+    timeseries_params:
+      max_age: 7d
+      time_step: 10m
+    series:
+      value:
+        data_type: int
   test_attr_probability:
     name: test_attr_probability
     type: plain
diff --git a/tests/test_example/dps_gen.py b/tests/test_example/dps_gen.py
index 9001a162..83885434 100644
--- a/tests/test_example/dps_gen.py
+++ b/tests/test_example/dps_gen.py
@@ -2,12 +2,12 @@
 
 import json
 import random
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
 
 
 class TimeContainer:
     def __init__(self):
-        self.time = datetime.now(timezone.utc) - timedelta(days=4)
+        self.time = datetime.now(UTC) - timedelta(days=4)
 
     def add_minutes(self, minutes: int):
         self.time += timedelta(minutes=minutes)
diff --git a/tests/test_example/dps_gen_realtime.py b/tests/test_example/dps_gen_realtime.py
index 9f701e87..1d5979c3 100644
--- a/tests/test_example/dps_gen_realtime.py
+++ b/tests/test_example/dps_gen_realtime.py
@@ -2,14 +2,12 @@
 
 import random
 from argparse import ArgumentParser
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
 from sys import stderr
 from time import sleep
 
 import requests
 
-UTC = timezone.utc
-
 
 def random_initial_location():
     latitude = random.uniform(39.0, 41.0)