diff --git a/docs/hooks.md b/docs/hooks.md
index 530d6121..02d09d16 100644
--- a/docs/hooks.md
+++ b/docs/hooks.md
@@ -600,6 +600,8 @@ Code reference: [`scheduler_register`][dp3.common.callback_registrar.CallbackRegistrar.scheduler_register]
 
 Most user-facing hooks return `list[DataPointTask]`.
 Whenever that happens, the returned tasks are fed back into the main ingestion system.
+Each returned task must do useful work: it must contain at least one datapoint, carry non-empty TTL tokens, or be a delete task.
+Empty `DataPointTask` objects are rejected during validation because they would otherwise be queued and processed without changing DP3 state.
 
 This creates a feedback loop:
 
diff --git a/docs/howto/test-module.md b/docs/howto/test-module.md
new file mode 100644
index 00000000..3f3a687a
--- /dev/null
+++ b/docs/howto/test-module.md
@@ -0,0 +1,153 @@
+# Test a secondary module
+
+DP3 includes helpers for writing focused unit tests for secondary modules without running a full
+worker, database, message broker, or snapshot scheduler.
+
+Use [`DP3ModuleTestCase`][dp3.testing.DP3ModuleTestCase] when you want to instantiate a
+module with the application's real `db_entities` model and then call registered hooks directly.
+The test registrar captures callbacks during module initialization and exposes runners for the
+common hook families.
+
+The config directory is read from the `DP3_CONFIG_DIR` environment variable unless a test class
+sets `config_dir` explicitly. Module configuration is read from `modules.<module_name>` in that
+config by default, where `<module_name>` is inferred from the module class's Python module name.
+
+```bash
+DP3_CONFIG_DIR=config python -m unittest discover -s tests -v
+```
+
+## Basic pattern
+
+```python
+from unittest.mock import patch
+
+from dp3.testing import DP3ModuleTestCase
+from modules.ip_exposure_profile import IPExposureProfile
+
+
+class TestIPExposureProfile(DP3ModuleTestCase):
+    module_class = IPExposureProfile
+
+    def test_open_port_creates_service_and_link(self):
+        dp = self.make_observation_datapoint("ip", "192.0.2.1", "open_ports", 443)
+
+        tasks = self.run_on_new_attr("ip", "open_ports", "192.0.2.1", dp)
+
+        self.assertDatapoint(tasks, etype="service", eid="192.0.2.1:443", attr="guessed_type")
+        self.assertDatapoint(tasks, etype="ip", eid="192.0.2.1", attr="services")
+
+    def test_updater_uses_mocked_external_lookup(self):
+        with patch.object(self.module, "_fetch_service_intel", return_value={"risk": "high"}):
+            tasks = self.run_periodic_update(
+                "service",
+                "192.0.2.1:443",
+                {"guessed_type": "https"},
+                hook_id="service_intel",
+            )
+
+        self.assertDatapoint(tasks, attr="external_risk", v="high")
+```
+
+## What the helper provides
+
+`DP3ModuleTestCase`:
+
+- loads `db_entities` from `DP3_CONFIG_DIR` or `config_dir` and builds a real `ModelSpec`,
+- creates a minimal `PlatformConfig`,
+- instantiates `module_class` with a test registrar,
+- creates validated `DataPointTask` and plain, observation, or timeseries datapoint objects using
+  the loaded model,
+- calls registered hooks directly,
+- provides partial-match assertions for emitted tasks, datapoints, and mutated records.
+
+The helper is intended for module-level unit tests. It does not run a database, task queues,
+worker processes, recursive task ingestion, or full linked snapshot loading.
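+
+Tasks built in tests go through the same pydantic validation as in a running worker, so the
+empty-task rule described in [Hooks](../hooks.md) applies here too: an empty `DataPointTask` is
+rejected unless the internal synthetic-task context is used, which the hook runners enable for the
+synthetic tasks they pass to entity-creation hooks. A minimal sketch inside a test method,
+assuming the app defines an `ip` entity:
+
+```python
+from pydantic import ValidationError
+
+from dp3.common.task import DataPointTask, task_context
+
+# An empty task (no datapoints, no TTL tokens, delete=False) fails validation...
+with task_context(self.model_spec):
+    self.assertRaises(ValidationError, DataPointTask, etype="ip", eid="192.0.2.1")
+
+# ...unless the synthetic-task context is used, as the hook runners do internally.
+with task_context(self.model_spec, allow_empty_data_point_task=True):
+    task = DataPointTask(etype="ip", eid="192.0.2.1")
+```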
+
+## Datapoint helpers
+
+Use the datapoint helpers to build values accepted by the loaded model specification:
+
+```python
+plain = self.make_plain_datapoint("ip", "192.0.2.1", "hostname", "host.example")
+observation = self.make_observation_datapoint("ip", "192.0.2.1", "open_ports", 443)
+timeseries = self.make_timeseries_datapoint(
+    "ip",
+    "192.0.2.1",
+    "traffic",
+    {"packets": [1, 2, 3], "bytes": [100, 200, 300]},
+)
+```
+
+For regular timeseries attributes, `make_timeseries_datapoint()` infers `t2` from `t1`, the
+configured `time_step`, and the number of samples when `t2` is not supplied.
+
+## Hook runners
+
+Common runners are available on the test case:
+
+- `run_allow_entity_creation(entity, eid, task=None)`
+- `run_on_entity_creation(entity, eid, task=None)`
+- `run_on_new_attr(entity, attr, eid, dp)`
+- `run_correlation_hooks(entity_type, record, master_record=None)`
+- `run_periodic_update(entity_type, eid, master_record, hook_id=None)`
+- `run_periodic_eid_update(entity_type, eid, hook_id=None)`
+- `run_scheduler_job(index_or_func)`
+
+Correlation tests pass the snapshot `record` explicitly. The record must contain `eid`.
+Scheduler jobs can be selected by registration index, callable, or callable name.
+
+## Assertions
+
+Assertions use partial matching: only fields supplied in the expected values are checked.
+
+```python
+self.assertDatapoint(tasks, etype="ip", attr="hostname", v="example.test")
+self.assertTaskEmitted(tasks, etype="ip", eid="192.0.2.1")
+self.assertNoTasks(tasks)
+self.assertNoDatapoints(tasks)
+self.assertRecordContains(record, exposure_score=10)
+self.assertRecordAttr(record, "exposure_score", 10)
+self.assertRecordUnchanged(before, after)
+```
+
+Snake-case aliases are also available: `assert_datapoint`, `assert_task_emitted`,
+`assert_no_tasks`, `assert_no_datapoints`, `assert_record_contains`, `assert_record_attr`, and
+`assert_record_unchanged`.
+
+## Registration assertions
+
+Use registration assertions when a test needs to verify callback coverage or dynamic hook
+registration.
+
+```python
+self.assert_registered("on_new_attr", entity="ip", attr="hostname")
+self.assert_registered_once("correlation", entity_type="service")
+self.assert_registered_attrs("service", expected_service_attrs)
+self.assert_scheduler_registered(func="reload_ip_groups", minute="*/10")
+```
+
+`assert_scheduler_registered()` accepts scheduler fields such as `minute`, `hour`, and `second`,
+along with `func` for matching the registered callable by object or function name.
+
+## Mocking external dependencies
+
+Patch external constructors or functions before module instantiation when the dependency is created
+in `__init__` or `load_config`:
+
+```python
+class TestDNSModule(DP3ModuleTestCase):
+    module_class = DNSModule
+
+    def setUp(self):
+        self.resolver_patcher = patch("modules.dns_module.Resolver", FakeResolver)
+        self.resolver_patcher.start()
+        self.addCleanup(self.resolver_patcher.stop)
+        super().setUp()
+```
+
+If patching is not convenient, use a test subclass as `module_class` and override the module's
+initialization or dependency construction while keeping the hook methods under test unchanged, as
+sketched at the end of this page.
+
+Deprecated registrar methods (`register_entity_hook` and `register_attr_hook`) are supported by the
+test registrar and emit `DeprecationWarning`. Prefer the modern registration methods in new module
+code and tests.
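+
+As a sketch of the subclass approach mentioned above (the `DNSModule`, `FakeResolver`, and
+`_make_resolver()` names are illustrative, not part of DP3):
+
+```python
+class DNSModuleForTest(DNSModule):
+    def _make_resolver(self):
+        # Override the hypothetical dependency factory; the hook methods stay untouched.
+        return FakeResolver()
+
+
+class TestDNSModuleSubclass(DP3ModuleTestCase):
+    module_class = DNSModuleForTest
+    # The inferred module name would come from the test file, so pin it explicitly
+    # to keep reading configuration from `modules.dns_module`.
+    module_name = "dns_module"
+```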
diff --git a/docs/modules.md b/docs/modules.md
index e2882436..822322d4 100644
--- a/docs/modules.md
+++ b/docs/modules.md
@@ -172,6 +172,16 @@ and configuration, see the [updater configuration](configuration/updater.md) pag
 - [`scheduler_register(...)`](hooks.md#scheduler_register) — CRON-style module-level scheduled
   callback for maintenance, polling, housekeeping, or shared-state reloads.
 
+## Testing modules
+
+Secondary modules can be unit-tested without running a full DP3 worker by using
+[`DP3ModuleTestCase`][dp3.testing.DP3ModuleTestCase]. The helper loads an application's
+real `db_entities` model from `DP3_CONFIG_DIR` or an explicit test fixture path, instantiates a
+module with a test callback registrar, and lets tests call registered hooks directly with validated
+`DataPointTask` and datapoint objects.
+
+See [Test a secondary module](howto/test-module.md) for examples and supported hook runners.
+
 ## Running module code in a separate thread
 
 The module is free to run its own code in separate threads or processes.
diff --git a/dp3/common/callback_registrar.py b/dp3/common/callback_registrar.py
index 675b4fcb..74c1d458 100644
--- a/dp3/common/callback_registrar.py
+++ b/dp3/common/callback_registrar.py
@@ -11,7 +11,7 @@
 from dp3.common.datatype import AnyEidT
 from dp3.common.scheduler import Scheduler
 from dp3.common.state import SharedFlag
-from dp3.common.task import DataPointTask
+from dp3.common.task import DataPointTask, task_context
 from dp3.common.types import ParsedTimedelta
 from dp3.core.updater import Updater
 from dp3.snapshots.snapshooter import SnapShooter
@@ -57,7 +57,8 @@ def on_entity_creation_in_snapshots(
         if not run_flag.isset():
             return []
         eid = record["eid"]
-        mock_task = DataPointTask(etype=etype, eid=eid, data_points=[])
+        with task_context(model_spec, allow_empty_data_point_task=True):
+            mock_task = DataPointTask(etype=etype, eid=eid, data_points=[])
         tasks = original_hook(eid, mock_task)
         write_datapoints_into_record(model_spec, tasks, record)
         return tasks
@@ -74,7 +75,8 @@ def on_attr_change_in_snapshots(
         if not run_flag.isset():
             return []
         eid = record["eid"]
-        mock_task = DataPointTask(etype=etype, eid=eid, data_points=[])
+        with task_context(model_spec, allow_empty_data_point_task=True):
+            mock_task = DataPointTask(etype=etype, eid=eid, data_points=[])
         tasks = original_hook(eid, mock_task)
         if isinstance(tasks, list):
             write_datapoints_into_record(model_spec, tasks, record)
diff --git a/dp3/common/task.py b/dp3/common/task.py
index e2ece755..b59d85a7 100644
--- a/dp3/common/task.py
+++ b/dp3/common/task.py
@@ -17,7 +17,9 @@
     Tag,
     TypeAdapter,
     ValidationError,
+    ValidationInfo,
     field_validator,
+    model_validator,
 )
 
 from pydantic_core.core_schema import FieldValidationInfo
@@ -40,9 +42,16 @@ def HASH(key: str) -> int:
 
 
 @contextmanager
-def task_context(model_spec: ModelSpec) -> Iterator[None]:
+def task_context(
+    model_spec: ModelSpec, *, allow_empty_data_point_task: bool = False
+) -> Iterator[None]:
     """Context manager for setting the `model_spec` context variable."""
-    token = _init_context_var.set({"model_spec": model_spec})
+    token = _init_context_var.set(
+        {
+            "model_spec": model_spec,
+            "allow_empty_data_point_task": allow_empty_data_point_task,
+        }
+    )
     try:
         yield
     finally:
@@ -182,6 +191,18 @@ def validate_eid(cls, v, info: FieldValidationInfo):
         else:
             raise AssertionError("Missing `model_spec` in context")
 
+    @model_validator(mode="after")
+    def validate_not_empty(self, info: ValidationInfo):
+        context = info.context
+        if context and context.get("allow_empty_data_point_task"):
+            return self
+        if not self.data_points and not self.ttl_tokens and not self.delete:
+            raise ValueError(
+                "DataPointTask must contain at least one datapoint, non-empty ttl_tokens, "
+                "or be a delete task."
+            )
+        return self
+
 
 def parse_data_point_task(task: str, model_spec: ModelSpec) -> DataPointTask:
     with task_context(model_spec):
diff --git a/dp3/task_processing/task_executor.py b/dp3/task_processing/task_executor.py
index 5703820d..92199a67 100644
--- a/dp3/task_processing/task_executor.py
+++ b/dp3/task_processing/task_executor.py
@@ -193,10 +193,10 @@ def refresh_on_entity_creation(
 
         for master_record in self.db.get_worker_master_records(
             worker_id, worker_cnt, etype, projection=projection
         ):
-            with task_context(self.model_spec):
+            with task_context(self.model_spec, allow_empty_data_point_task=True):
                 task = DataPointTask(etype=etype, eid=master_record["_id"])
-                self.log.debug(f"Refreshing {etype}/{task.eid}")
-                new_tasks += self._task_entity_hooks[task.etype].run_on_creation(task.eid, task)
+            self.log.debug(f"Refreshing {etype}/{task.eid}")
+            new_tasks += self._task_entity_hooks[task.etype].run_on_creation(task.eid, task)
         return new_tasks
 
diff --git a/dp3/testing/__init__.py b/dp3/testing/__init__.py
new file mode 100644
index 00000000..350a6e9d
--- /dev/null
+++ b/dp3/testing/__init__.py
@@ -0,0 +1,13 @@
+"""Testing helpers for DP3 applications."""
+
+from dp3.testing.case import DP3ModuleTestCase
+from dp3.testing.config import CONFIG_DIR_ENV, resolve_config_dir
+from dp3.testing.registrar import HookRegistration, TestCallbackRegistrar
+
+__all__ = [
+    "CONFIG_DIR_ENV",
+    "DP3ModuleTestCase",
+    "HookRegistration",
+    "TestCallbackRegistrar",
+    "resolve_config_dir",
+]
diff --git a/dp3/testing/assertions.py b/dp3/testing/assertions.py
new file mode 100644
index 00000000..633a8263
--- /dev/null
+++ b/dp3/testing/assertions.py
@@ -0,0 +1,140 @@
+"""Assertion helpers for DP3 module tests."""
+
+import unittest
+from collections.abc import Iterable
+from typing import Any
+
+from pydantic import BaseModel
+
+from dp3.common.datapoint import DataPointBase
+from dp3.common.task import DataPointTask
+
+_UNSET = object()
+
+
+class ModuleAssertions(unittest.TestCase):
+    """Partial-match assertions for module hook outputs."""
+
+    def assert_no_tasks(self, tasks: Iterable[DataPointTask]) -> None:
+        self.assertEqual([], list(tasks))
+
+    def assert_no_datapoints(self, tasks: Iterable[DataPointTask]) -> None:
+        self.assertEqual([], list(self.iter_datapoints(tasks)))
+
+    def assert_task_emitted(  # noqa: PLR0913
+        self,
+        tasks: Iterable[DataPointTask],
+        *,
+        etype: Any = _UNSET,
+        eid: Any = _UNSET,
+        data_points: Any = _UNSET,
+        tags: Any = _UNSET,
+        ttl_tokens: Any = _UNSET,
+        delete: Any = _UNSET,
+    ) -> DataPointTask:
+        """Assert that a task matching the supplied ``DataPointTask`` fields was emitted."""
+        expected = _selected_fields(
+            etype=etype,
+            eid=eid,
+            data_points=data_points,
+            tags=tags,
+            ttl_tokens=ttl_tokens,
+            delete=delete,
+        )
+        task_list = list(tasks)
+        for task in task_list:
+            if self._partial_match(dump_value(task), expected):
+                return task
+        self.fail(f"No emitted task matched {expected!r}. Emitted tasks: {dump_value(task_list)!r}")
+
+    def assert_datapoint(  # noqa: PLR0913
+        self,
+        tasks: Iterable[DataPointTask],
+        *,
+        etype: Any = _UNSET,
+        eid: Any = _UNSET,
+        attr: Any = _UNSET,
+        src: Any = _UNSET,
+        v: Any = _UNSET,
+        c: Any = _UNSET,
+        t1: Any = _UNSET,
+        t2: Any = _UNSET,
+    ) -> DataPointBase:
+        """Assert that a datapoint matching the supplied ``DataPointBase`` fields was emitted."""
+        expected = _selected_fields(
+            etype=etype,
+            eid=eid,
+            attr=attr,
+            src=src,
+            v=v,
+            c=c,
+            t1=t1,
+            t2=t2,
+        )
+        datapoints = list(self.iter_datapoints(tasks))
+        for dp in datapoints:
+            if self._partial_match(dump_value(dp), expected):
+                return dp
+        self.fail(
+            f"No emitted datapoint matched {expected!r}. "
+            f"Emitted datapoints: {dump_value(datapoints)!r}"
+        )
+
+    def assert_record_contains(self, record: dict, **expected) -> None:
+        if not self._partial_match(record, expected):
+            self.fail(f"Record {record!r} does not contain expected values {expected!r}")
+
+    def assert_record_attr(self, record: dict, attr: str, expected: Any) -> None:
+        if attr not in record:
+            self.fail(f"Record {record!r} does not contain attribute {attr!r}")
+        if not self._partial_match(record[attr], expected):
+            self.fail(
+                f"Record attribute {attr!r} value {record[attr]!r} "
+                f"does not match expected value {expected!r}"
+            )
+
+    def assert_record_unchanged(self, before: dict, after: dict) -> None:
+        self.assertEqual(dump_value(before), dump_value(after))
+
+    assertNoTasks = assert_no_tasks
+    assertNoDatapoints = assert_no_datapoints
+    assertTaskEmitted = assert_task_emitted
+    assertDatapoint = assert_datapoint
+    assertRecordContains = assert_record_contains
+    assertRecordAttr = assert_record_attr
+    assertRecordUnchanged = assert_record_unchanged
+
+    @staticmethod
+    def iter_datapoints(tasks: Iterable[DataPointTask]) -> Iterable[DataPointBase]:
+        for task in tasks:
+            yield from task.data_points
+
+    @classmethod
+    def _partial_match(cls, actual: Any, expected: Any) -> bool:
+        actual = dump_value(actual)
+        expected = dump_value(expected)
+        if isinstance(expected, dict):
+            if not isinstance(actual, dict):
+                return False
+            return all(
+                key in actual and cls._partial_match(actual[key], value)
+                for key, value in expected.items()
+            )
+        return actual == expected
+
+
+def _selected_fields(**fields: Any) -> dict[str, Any]:
+    return {key: value for key, value in fields.items() if value is not _UNSET}
+
+
+def dump_value(value: Any) -> Any:
+    """Convert pydantic values recursively to plain Python containers."""
+    if isinstance(value, BaseModel):
+        return value.model_dump()
+    if isinstance(value, list):
+        return [dump_value(item) for item in value]
+    if isinstance(value, tuple):
+        return tuple(dump_value(item) for item in value)
+    if isinstance(value, dict):
+        return {key: dump_value(item) for key, item in value.items()}
+    return value
diff --git a/dp3/testing/case.py b/dp3/testing/case.py
new file mode 100644
index 00000000..13197ff3
--- /dev/null
+++ b/dp3/testing/case.py
@@ -0,0 +1,338 @@
+"""unittest base class for DP3 secondary module tests."""
+
+import copy
+import unittest
+from collections.abc import Iterable, Mapping, Sequence
+from datetime import datetime
+from typing import Any, Callable, Generic, Optional, TypeVar, Union
+
+from dp3.common.attrspec import AttrType
+from dp3.common.base_module import BaseModule
+from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig
+from dp3.common.datapoint import DataPointBase
+from dp3.common.task import DataPointTask, task_context
+from dp3.common.types import UTC
+from dp3.common.utils import get_func_name
+from dp3.testing.assertions import ModuleAssertions
+from dp3.testing.config import (
+    CONFIG_DIR_ENV,
+    build_model_spec,
+    build_platform_config,
+    get_module_config,
+    load_config,
+    resolve_config_dir,
+)
+from dp3.testing.registrar import HookRegistration, TestCallbackRegistrar
+
+ModuleT = TypeVar("ModuleT", bound=BaseModule)
+
+
+class DP3ModuleTestCase(ModuleAssertions, unittest.TestCase, Generic[ModuleT]):
+    """Base class for unit tests of DP3 secondary modules.
+
+    By default the app configuration directory is read from ``DP3_CONFIG_DIR``. Subclasses may set
+    ``config_dir`` explicitly when they need a fixed fixture config.
+    """
+
+    config_dir: Optional[str] = None
+    config_env_var: str = CONFIG_DIR_ENV
+    module_class: type[ModuleT]
+    module_name: Optional[str] = None
+    module_config: Optional[dict] = None
+    app_name: str = "test"
+    process_index: int = 0
+    num_processes: int = 1
+    module: ModuleT
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.config_base_path = self.resolve_config_dir()
+        self.config = self.load_config()
+        self.model_spec = self.make_model_spec(self.config)
+        self.platform_config = self.make_platform_config()
+        self.registrar = self.make_registrar()
+        self.module = self.make_module(self.module_class, self.get_module_config(), self.registrar)
+
+    def resolve_config_dir(self) -> str:
+        return resolve_config_dir(self.config_dir, self.config_env_var)
+
+    def load_config(self) -> HierarchicalDict:
+        return load_config(self.config_base_path, self.config_env_var)
+
+    def make_model_spec(self, config: HierarchicalDict) -> ModelSpec:
+        return build_model_spec(config)
+
+    def make_platform_config(self) -> PlatformConfig:
+        return build_platform_config(
+            self.config,
+            self.model_spec,
+            self.config_base_path,
+            app_name=self.app_name,
+            process_index=self.process_index,
+            num_processes=self.num_processes,
+            env_var=self.config_env_var,
+        )
+
+    def get_module_config(self) -> dict:
+        if self.module_config is not None:
+            return copy.deepcopy(self.module_config)
+        return copy.deepcopy(get_module_config(self.config, self.get_module_name()))
+
+    def get_module_name(self) -> Optional[str]:
+        if self.module_name is not None:
+            return self.module_name
+        return self.module_class.__module__.split(".")[-1]
+
+    def make_registrar(self) -> TestCallbackRegistrar:
+        return TestCallbackRegistrar(self.model_spec)
+
+    def make_module(
+        self,
+        module_class: type[ModuleT],
+        module_config: dict[str, Any],
+        registrar: TestCallbackRegistrar,
+    ) -> ModuleT:
+        return module_class(self.platform_config, module_config, registrar)
+
+    def make_task(
+        self,
+        etype: str,
+        eid: Any,
+        data_points: Optional[list[Union[dict, DataPointBase]]] = None,
+        tags: Optional[list] = None,
+        ttl_tokens: Optional[dict] = None,
+        delete: bool = False,
+    ) -> DataPointTask:
+        with task_context(self.model_spec):
+            return DataPointTask(
+                etype=etype,
+                eid=eid,
+                data_points=data_points or [],
+                tags=tags or [],
+                ttl_tokens=ttl_tokens,
+                delete=delete,
+            )
+
+    def make_datapoint(
+        self,
+        etype: str,
+        eid: Any,
+        attr: str,
+        v: Any,
+        src: str = "test",
+        **fields,
+    ) -> DataPointBase:
+        task = self.make_task(
+            etype,
+            eid,
+            [dict({"etype": etype, "eid": eid, "attr": attr, "src": src, "v": v}, **fields)],
+        )
+        return task.data_points[0]
+
+    def make_plain_datapoint(
+        self, etype: str, eid: Any, attr: str, v: Any, src: str = "test", **fields
+    ) -> DataPointBase:
+        return self.make_datapoint(etype, eid, attr, v, src=src, **fields)
+
+    def make_observation_datapoint(
+        self,
+        etype: str,
+        eid: Any,
+        attr: str,
+        v: Any,
+        src: str = "test",
+        t1: Optional[datetime] = None,
+        t2: Optional[datetime] = None,
+        c: float = 1.0,
+        **fields,
+    ) -> DataPointBase:
+        t1 = t1 or datetime.now(UTC)
+        data = {"t1": t1, "c": c, **fields}
+        if t2 is not None:
+            data["t2"] = t2
+        return self.make_datapoint(etype, eid, attr, v, src=src, **data)
+
+    def make_timeseries_datapoint(
+        self,
+        etype: str,
+        eid: Any,
+        attr: str,
+        v: Mapping[str, Sequence[Any]],
+        src: str = "test",
+        t1: Optional[datetime] = None,
+        t2: Optional[datetime] = None,
+        **fields,
+    ) -> DataPointBase:
+        """Create a validated timeseries datapoint.
+
+        For regular timeseries attributes, ``t2`` is inferred when omitted by using the
+        attribute's configured ``time_step`` and the number of samples in ``v``:
+        ``t2 = t1 + len(series) * time_step``. For irregular timeseries, ``t1`` is inferred from
+        the first ``time`` value when omitted. For irregular-interval timeseries, ``t1`` is
+        inferred from the first ``time_first`` value when omitted.
+        """
+        attr_spec = self.model_spec.attributes[etype, attr]
+        if attr_spec.t != AttrType.TIMESERIES:
+            raise ValueError(f"Attribute {etype}/{attr} is not a timeseries attribute.")
+
+        values = dict(v)
+        t1 = t1 or self._infer_timeseries_t1(attr_spec, values) or datetime.now(UTC)
+        if t2 is None and attr_spec.timeseries_type == "regular":
+            time_step = attr_spec.timeseries_params.time_step
+            if time_step is None:
+                raise ValueError(f"Regular timeseries attribute {etype}/{attr} has no time_step.")
+            t2 = t1 + self._timeseries_length(values) * time_step
+
+        data = {"t1": t1, **fields}
+        if t2 is not None:
+            data["t2"] = t2
+        return self.make_datapoint(etype, eid, attr, values, src=src, **data)
+
+    @staticmethod
+    def _infer_timeseries_t1(attr_spec, values: Mapping[str, Sequence[Any]]) -> Optional[datetime]:
+        if attr_spec.timeseries_type == "irregular" and values.get("time"):
+            return values["time"][0]
+        if attr_spec.timeseries_type == "irregular_intervals" and values.get("time_first"):
+            return values["time_first"][0]
+        return None
+
+    @staticmethod
+    def _timeseries_length(values: Mapping[str, Sequence[Any]]) -> int:
+        try:
+            return len(next(iter(values.values())))
+        except StopIteration as e:
+            raise ValueError("Timeseries datapoint values cannot be empty.") from e
+
+    def run_task_hooks(self, hook_type: str, task: DataPointTask) -> None:
+        self.registrar.run_task_hooks(hook_type, task)
+
+    def run_allow_entity_creation(
+        self, entity: str, eid: Any, task: Optional[DataPointTask] = None
+    ) -> bool:
+        task = task or self._make_synthetic_task(entity, eid)
+        return self.registrar.run_allow_entity_creation(entity, eid, task)
+
+    def run_on_entity_creation(
+        self, entity: str, eid: Any, task: Optional[DataPointTask] = None
+    ) -> list[DataPointTask]:
+        task = task or self._make_synthetic_task(entity, eid)
+        return self.registrar.run_on_entity_creation(entity, eid, task)
+
+    def _make_synthetic_task(self, etype: str, eid: Any) -> DataPointTask:
+        with task_context(self.model_spec, allow_empty_data_point_task=True):
+            return DataPointTask(etype=etype, eid=eid)
+
+    def run_on_new_attr(self, entity: str, attr: str, eid: Any, dp: DataPointBase):
+        return self.registrar.run_on_new_attr(entity, attr, eid, dp)
+
+    def run_correlation_hooks(
+        self,
+        entity_type: str,
+        record: dict,
+        master_record: Optional[dict] = None,
+    ) -> list[DataPointTask]:
+        return self.registrar.run_correlation_hooks(entity_type, record, master_record)
+
+    def run_periodic_update(
+        self, entity_type: str, eid: Any, master_record: dict, hook_id: Optional[str] = None
+    ) -> list[DataPointTask]:
+        return self.registrar.run_periodic_update(entity_type, eid, master_record, hook_id)
+
+    def run_periodic_eid_update(
+        self, entity_type: str, eid: Any, hook_id: Optional[str] = None
+    ) -> list[DataPointTask]:
+        return self.registrar.run_periodic_eid_update(entity_type, eid, hook_id)
+
+    def run_scheduler_job(self, job: Union[int, str, Callable, HookRegistration]):
+        return self.registrar.run_scheduler_job(job)
+
+    def registered(self, kind: Optional[str] = None, **fields) -> list[HookRegistration]:
+        """Return registrations matching ``kind`` and the supplied registration fields."""
+        return [
+            registration
+            for registration in self.registrar.registrations
+            if self._registration_matches(registration, kind, fields)
+        ]
+
+    def assert_registered(self, kind: str, **fields) -> HookRegistration:
+        """Assert that at least one callback registration matches the supplied fields."""
+        matches = self.registered(kind, **fields)
+        if not matches:
+            self.fail(
+                f"No registration matched kind={kind!r}, fields={fields!r}. "
+                f"Registered callbacks: {self.registrar.registrations!r}"
+            )
+        return matches[0]
+
+    def assert_registered_once(self, kind: str, **fields) -> HookRegistration:
+        """Assert that exactly one callback registration matches the supplied fields."""
+        matches = self.registered(kind, **fields)
+        if len(matches) != 1:
+            self.fail(
+                f"Expected one registration matching kind={kind!r}, fields={fields!r}; "
+                f"found {len(matches)}: {matches!r}"
+            )
+        return matches[0]
+
+    def assert_registered_attrs(
+        self,
+        entity: str,
+        expected_attrs: Iterable[str],
+        *,
+        kind: str = "on_new_attr",
+        exact: bool = True,
+    ) -> list[HookRegistration]:
+        """Assert that attribute hook registrations exist for the supplied entity attributes."""
+        expected = set(expected_attrs)
+        matches = self.registered(kind, entity=entity)
+        actual = {registration.attr for registration in matches if registration.attr is not None}
+        if exact:
+            self.assertEqual(expected, actual)
+        else:
+            missing = expected - actual
+            if missing:
+                self.fail(f"Missing registrations for attributes: {sorted(missing)!r}")
+        return [registration for registration in matches if registration.attr in expected]
+
+    def assert_scheduler_registered(self, **fields) -> HookRegistration:
+        """Assert that at least one scheduler callback registration matches the supplied fields."""
+        return self.assert_registered("scheduler", **fields)
+
+    assertRegistered = assert_registered
+    assertRegisteredOnce = assert_registered_once
+    assertRegisteredAttrs = assert_registered_attrs
+    assertSchedulerRegistered = assert_scheduler_registered
+
+    def _registration_matches(
+        self, registration: HookRegistration, kind: Optional[str], fields: dict[str, Any]
+    ) -> bool:
+        if kind is not None and registration.kind != kind:
+            return False
+        for key, expected in fields.items():
+            found, actual = _registration_field(registration, key)
+            if not found:
+                return False
+            if key in {"hook", "func"} and isinstance(expected, str):
+                if not _callable_name_matches(actual, expected):
+                    return False
+            elif not self._partial_match(actual, expected):
+                return False
+        return True
+
+
+def _registration_field(registration: HookRegistration, key: str) -> tuple[bool, Any]:
+    if key == "func":
+        return True, registration.hook
+    if hasattr(registration, key):
+        return True, getattr(registration, key)
+    if key in registration.extra:
+        return True, registration.extra[key]
+    schedule = registration.extra.get("schedule", {})
+    if key in schedule:
+        return True, schedule[key]
+    return False, None
+
+
+def _callable_name_matches(func: Callable, expected: str) -> bool:
+    func_name = get_func_name(func)
+    return func_name == expected or func_name.endswith(f".{expected}")
diff --git a/dp3/testing/config.py b/dp3/testing/config.py
new file mode 100644
index 00000000..11238f4f
--- /dev/null
+++ b/dp3/testing/config.py
@@ -0,0 +1,66 @@
+"""Configuration helpers for DP3 module tests."""
+
+import os
+from contextlib import suppress
+from typing import Optional
+
+from dp3.common.config import HierarchicalDict, ModelSpec, PlatformConfig, read_config_dir
+
+CONFIG_DIR_ENV = "DP3_CONFIG_DIR"
+
+
+def resolve_config_dir(config_dir: Optional[str] = None, env_var: str = CONFIG_DIR_ENV) -> str:
+    """Return an absolute DP3 config directory path.
+
+    Explicit ``config_dir`` values take precedence. If no explicit path is supplied, the path is
+    read from ``env_var``.
+    """
+    resolved = config_dir or os.environ.get(env_var)
+    if not resolved:
+        raise ValueError(
+            f"DP3 module tests require a config directory. Set {env_var} or pass "
+            "config_dir explicitly."
+        )
+    return os.path.abspath(resolved)
+
+
+def load_config(
+    config_dir: Optional[str] = None, env_var: str = CONFIG_DIR_ENV
+) -> HierarchicalDict:
+    """Load a DP3 config directory for module tests."""
+    return read_config_dir(resolve_config_dir(config_dir, env_var), recursive=True)
+
+
+def build_model_spec(config: HierarchicalDict) -> ModelSpec:
+    """Build a model specification from loaded DP3 configuration."""
+    return ModelSpec(config.get("db_entities"))
+
+
+def build_platform_config(
+    config: HierarchicalDict,
+    model_spec: ModelSpec,
+    config_dir: Optional[str] = None,
+    *,
+    app_name: str = "test",
+    process_index: int = 0,
+    num_processes: int = 1,
+    env_var: str = CONFIG_DIR_ENV,
+) -> PlatformConfig:
+    """Build the minimal platform config needed by secondary module unit tests."""
+    with suppress(Exception):
+        num_processes = config.get("processing_core.worker_processes", num_processes)
+    return PlatformConfig(
+        app_name=app_name,
+        config_base_path=resolve_config_dir(config_dir, env_var),
+        config=config,
+        model_spec=model_spec,
+        process_index=process_index,
+        num_processes=num_processes,
+    )
+
+
+def get_module_config(config: HierarchicalDict, module_name: Optional[str]) -> dict:
+    """Return module-specific config from loaded app config."""
+    if module_name is None:
+        return {}
+    return config.get(f"modules.{module_name}", {})
diff --git a/dp3/testing/registrar.py b/dp3/testing/registrar.py
new file mode 100644
index 00000000..56e13936
--- /dev/null
+++ b/dp3/testing/registrar.py
@@ -0,0 +1,485 @@
+"""Test callback registrar for DP3 secondary modules."""
+
+import copy
+import logging
+import warnings
+from collections import defaultdict
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import Any, Callable, Optional, Union
+
+from event_count_logger import DummyEventGroup
+
+from dp3.common.attrspec import AttrType
+from dp3.common.config import ModelSpec
+from dp3.common.datapoint import DataPointBase
+from dp3.common.task import DataPointTask, task_context
+from dp3.common.utils import get_func_name
+from dp3.snapshots.snapshot_hooks import (
+    SnapshotCorrelationHookContainer,
+    SnapshotTimeseriesHookContainer,
+)
+
+
+@dataclass
+class HookRegistration:
+    """Captured callback registration made by a secondary module."""
+
+    kind: str
+    hook: Callable
+    entity: Optional[str] = None
+    attr: Optional[str] = None
+    hook_type: Optional[str] = None
+    hook_id: Optional[str] = None
+    entity_type: Optional[str] = None
+    attr_type: Optional[str] = None
+    depends_on: list[list[str]] = field(default_factory=list)
+    may_change: list[list[str]] = field(default_factory=list)
+    refresh: Any = None
+    period: Any = None
+    deprecated: bool = False
+    extra: dict[str, Any] = field(default_factory=dict)
+
+
+def _drop_master_for_test(hook: Callable[[str, dict], Any]) -> Callable[[str, dict, dict], Any]:
+    @wraps(hook)
+    def wrapped(entity_type: str, record: dict, _master_record: dict):
+        return hook(entity_type, record)
+
+    return wrapped
+
+
+class TestCallbackRegistrar:
+    """Callback registrar implementation for module unit tests."""
+
+    attr_spec_t_to_on_attr = {
+        AttrType.PLAIN: "on_new_plain",
+        AttrType.OBSERVATIONS: "on_new_observation",
+        AttrType.TIMESERIES: "on_new_ts_chunk",
+    }
+
+    def __init__(self, model_spec: ModelSpec, log: Optional[logging.Logger] = None):
+        self.model_spec = model_spec
+        self.log = log or logging.getLogger(self.__class__.__name__)
+        self.registrations: list[HookRegistration] = []
+        self._task_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list)
+        self._allow_creation_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list)
+        self._on_creation_hooks: defaultdict[str, list[HookRegistration]] = defaultdict(list)
+        self._attr_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = defaultdict(list)
+        self._snapshot_init_hooks: list[HookRegistration] = []
+        self._snapshot_finalize_hooks: list[HookRegistration] = []
+        self._periodic_record_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = (
+            defaultdict(list)
+        )
+        self._periodic_eid_hooks: defaultdict[tuple[str, str], list[HookRegistration]] = (
+            defaultdict(list)
+        )
+        self._scheduler_jobs: list[HookRegistration] = []
+
+        self._correlation_hooks = SnapshotCorrelationHookContainer(
+            self.log, model_spec, DummyEventGroup()
+        )
+        self._timeseries_hooks = SnapshotTimeseriesHookContainer(
+            self.log, model_spec, DummyEventGroup()
+        )
+
+    def scheduler_register(  # noqa: PLR0913
+        self,
+        func: Callable,
+        *,
+        func_args: Union[list, tuple] = None,
+        func_kwargs: dict = None,
+        year: Union[int, str] = None,
+        month: Union[int, str] = None,
+        day: Union[int, str] = None,
+        week: Union[int, str] = None,
+        day_of_week: Union[int, str] = None,
+        hour: Union[int, str] = None,
+        minute: Union[int, str] = None,
+        second: Union[int, str] = None,
+        timezone: str = "UTC",
+        misfire_grace_time: int = 1,
+    ) -> int:
+        schedule = {
+            "year": year,
+            "month": month,
+            "day": day,
+            "week": week,
+            "day_of_week": day_of_week,
+            "hour": hour,
+            "minute": minute,
+            "second": second,
+            "timezone": timezone,
+            "misfire_grace_time": misfire_grace_time,
+        }
+        reg = HookRegistration(
+            kind="scheduler",
+            hook=func,
+            extra={
+                "func_args": list(func_args or []),
+                "func_kwargs": dict(func_kwargs or {}),
+                "schedule": schedule,
+            },
+        )
+        self._record(reg)
+        self._scheduler_jobs.append(reg)
+        return len(self._scheduler_jobs) - 1
+
+    def register_task_hook(self, hook_type: str, hook: Callable):
+        if hook_type != "on_task_start":
+            raise ValueError(f"Hook type '{hook_type}' doesn't exist.")
+        reg = HookRegistration(kind="task", hook_type=hook_type, hook=hook)
+        self._record(reg)
+        self._task_hooks[hook_type].append(reg)
+
+    def register_allow_entity_creation_hook(self, hook: Callable, entity: str):
+        self._validate_entity(entity)
+        reg = HookRegistration(kind="allow_entity_creation", entity=entity, hook=hook)
+        self._record(reg)
+        self._allow_creation_hooks[entity].append(reg)
+
+    def register_on_entity_creation_hook(
+        self,
+        hook: Callable,
+        entity: str,
+        refresh: Any = None,
+        may_change: list[list[str]] = None,
+    ):
+        self._validate_entity(entity)
+        if refresh is not None and may_change is None:
+            raise ValueError("'may_change' must be specified if 'refresh' is specified")
+        reg = HookRegistration(
+            kind="on_entity_creation",
+            entity=entity,
+            hook=hook,
+            refresh=refresh,
+            may_change=copy.deepcopy(may_change or []),
+        )
+        self._record(reg)
+        self._on_creation_hooks[entity].append(reg)
+
+    def register_on_new_attr_hook(
+        self,
+        hook: Callable,
+        entity: str,
+        attr: str,
+        refresh: Any = None,
+        may_change: list[list[str]] = None,
+    ):
+        hook_type = self._hook_type_for_attr(entity, attr)
+        if refresh is not None and may_change is None:
+            raise ValueError("'may_change' must be specified if 'refresh' is specified")
+        reg = HookRegistration(
+            kind="on_new_attr",
+            hook_type=hook_type,
+            entity=entity,
+            attr=attr,
+            hook=hook,
+            refresh=refresh,
+            may_change=copy.deepcopy(may_change or []),
+        )
+        self._record(reg)
+        self._attr_hooks[entity, attr].append(reg)
+
+    def register_entity_hook(self, hook_type: str, hook: Callable, entity: str):
+        warnings.warn(
+            "register_entity_hook() is deprecated; use "
+            "register_allow_entity_creation_hook() or register_on_entity_creation_hook().",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        self._validate_entity(entity)
+        if hook_type == "allow_entity_creation":
+            reg = HookRegistration(
+                kind="allow_entity_creation", entity=entity, hook=hook, deprecated=True
+            )
+            self._record(reg)
+            self._allow_creation_hooks[entity].append(reg)
+            return
+        if hook_type == "on_entity_creation":
+            reg = HookRegistration(
+                kind="on_entity_creation", entity=entity, hook=hook, deprecated=True
+            )
+            self._record(reg)
+            self._on_creation_hooks[entity].append(reg)
+            return
+        raise ValueError(f"Hook type '{hook_type}' doesn't exist.")
+
+    def register_attr_hook(self, hook_type: str, hook: Callable, entity: str, attr: str):
+        warnings.warn(
+            "register_attr_hook() is deprecated; use register_on_new_attr_hook().",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        expected_hook_type = self._hook_type_for_attr(entity, attr)
+        if hook_type != expected_hook_type:
+            raise ValueError(f"Hook type '{hook_type}' doesn't exist for {entity}/{attr}.")
+        reg = HookRegistration(
+            kind="on_new_attr",
+            hook_type=hook_type,
+            entity=entity,
+            attr=attr,
+            hook=hook,
+            deprecated=True,
+        )
+        self._record(reg)
+        self._attr_hooks[entity, attr].append(reg)
+
+    def register_timeseries_hook(self, hook: Callable, entity_type: str, attr_type: str):
+        self._timeseries_hooks.register(hook, entity_type, attr_type)
+        reg = HookRegistration(
+            kind="timeseries", hook=hook, entity_type=entity_type, attr_type=attr_type
+        )
+        self._record(reg)
+
+    def register_correlation_hook(
+        self,
+        hook: Callable,
+        entity_type: str,
+        depends_on: list[list[str]],
+        may_change: list[list[str]],
+    ):
+        wrapped = _drop_master_for_test(hook)
+        hook_id = self._correlation_hooks.register(wrapped, entity_type, depends_on, may_change)
+        reg = HookRegistration(
+            kind="correlation",
+            hook=hook,
+            hook_id=hook_id,
+            entity_type=entity_type,
+            depends_on=copy.deepcopy(depends_on),
+            may_change=copy.deepcopy(may_change),
+        )
+        self._record(reg)
+
+    def register_correlation_hook_with_master_record(
+        self,
+        hook: Callable,
+        entity_type: str,
+        depends_on: list[list[str]],
+        may_change: list[list[str]],
+    ):
+        hook_id = self._correlation_hooks.register(hook, entity_type, depends_on, may_change)
+        reg = HookRegistration(
+            kind="correlation_with_master_record",
+            hook=hook,
+            hook_id=hook_id,
+            entity_type=entity_type,
+            depends_on=copy.deepcopy(depends_on),
+            may_change=copy.deepcopy(may_change),
+        )
+        self._record(reg)
+
+    def register_snapshot_init_hook(self, hook: Callable):
+        reg = HookRegistration(kind="snapshot_init", hook=hook)
+        self._record(reg)
+        self._snapshot_init_hooks.append(reg)
+
+    def register_snapshot_finalize_hook(self, hook: Callable):
+        reg = HookRegistration(kind="snapshot_finalize", hook=hook)
+        self._record(reg)
+        self._snapshot_finalize_hooks.append(reg)
+
+    def register_periodic_update_hook(
+        self, hook: Callable, hook_id: str, entity_type: str, period: Any
+    ):
+        self._validate_entity(entity_type)
+        reg = HookRegistration(
+            kind="periodic_update",
+            hook=hook,
+            hook_id=hook_id,
+            entity_type=entity_type,
+            period=period,
+        )
+        self._record(reg)
+        self._periodic_record_hooks[entity_type, hook_id].append(reg)
+
+    def register_periodic_eid_update_hook(
+        self, hook: Callable, hook_id: str, entity_type: str, period: Any
+    ):
+        self._validate_entity(entity_type)
+        reg = HookRegistration(
+            kind="periodic_eid_update",
+            hook=hook,
+            hook_id=hook_id,
+            entity_type=entity_type,
+            period=period,
+        )
+        self._record(reg)
+        self._periodic_eid_hooks[entity_type, hook_id].append(reg)
+
+    def run_task_hooks(self, hook_type: str, task: DataPointTask) -> None:
+        for reg in self._task_hooks[hook_type]:
+            reg.hook(task)
+
+    def run_allow_entity_creation(self, entity: str, eid: Any, task: DataPointTask) -> bool:
+        return all(reg.hook(eid, task) for reg in self._allow_creation_hooks[entity])
+
+    def run_on_entity_creation(
+        self, entity: str, eid: Any, task: DataPointTask
+    ) -> list[DataPointTask]:
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for reg in self._on_creation_hooks[entity]:
+                hook_tasks = reg.hook(eid, task)
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    def run_on_new_attr(self, entity: str, attr: str, eid: Any, dp: DataPointBase):
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for reg in self._attr_hooks[entity, attr]:
+                hook_tasks = reg.hook(eid, dp)
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    def run_timeseries_hook(
+        self, entity_type: str, attr_type: str, attr_history: list[dict]
+    ) -> list[DataPointTask]:
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for hook in self._timeseries_hooks._hooks[entity_type, attr_type]:
+                hook_tasks = hook(entity_type, attr_type, attr_history)
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    def run_correlation_hooks(
+        self,
+        entity_type: str,
+        record: dict,
+        master_record: Optional[dict] = None,
+    ) -> list[DataPointTask]:
+        eid = self._assert_record_eid(record)
+        return self.run_correlation_hooks_for_entities(
+            {(entity_type, eid): record}, {(entity_type, eid): master_record or {}}
+        )
+
+    def run_correlation_hooks_for_entities(
+        self, entities: dict[tuple[str, Any], dict], master_records: Optional[dict] = None
+    ) -> list[DataPointTask]:
+        master_records = master_records or {}
+        for entity_type, _ in entities:
+            self._validate_entity(entity_type)
+        for record in entities.values():
+            self._assert_record_eid(record)
+
+        entity_types = {etype for etype, _ in entities}
+        hook_subset = [
+            (hook_id, hook, etype)
+            for etype in entity_types
+            for hook_id, hook in self._correlation_hooks._hooks[etype]
+        ]
+        topological_order = self._correlation_hooks._dependency_graph.topological_order
+        hook_subset.sort(key=lambda item: topological_order.index(item[0]))
+
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for _, hook, etype in hook_subset:
+                for entity_key, record in entities.items():
+                    if entity_key[0] != etype:
+                        continue
+                    hook_tasks = hook(etype, record, master_records.get(entity_key, {}))
+                    if isinstance(hook_tasks, list):
+                        tasks.extend(hook_tasks)
+        return tasks
+
+    def run_snapshot_init_hooks(self) -> list[DataPointTask]:
+        return self._run_no_arg_hooks(self._snapshot_init_hooks)
+
+    def run_snapshot_finalize_hooks(self) -> list[DataPointTask]:
+        return self._run_no_arg_hooks(self._snapshot_finalize_hooks)
+
+    def run_periodic_update(
+        self,
+        entity_type: str,
+        eid: Any,
+        master_record: dict,
+        hook_id: Optional[str] = None,
+    ) -> list[DataPointTask]:
+        hooks = self._matching_update_hooks(self._periodic_record_hooks, entity_type, hook_id)
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for reg in hooks:
+                hook_tasks = reg.hook(entity_type, eid, master_record)
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    def run_periodic_eid_update(
+        self, entity_type: str, eid: Any, hook_id: Optional[str] = None
+    ) -> list[DataPointTask]:
+        hooks = self._matching_update_hooks(self._periodic_eid_hooks, entity_type, hook_id)
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for reg in hooks:
+                hook_tasks = reg.hook(entity_type, eid)
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    def get_scheduler_job(
+        self, job: Union[int, str, Callable, HookRegistration]
+    ) -> HookRegistration:
+        """Return a registered scheduler job by index, callable, or callable name."""
+        if isinstance(job, int):
+            return self._scheduler_jobs[job]
+        if isinstance(job, HookRegistration):
+            if job.kind != "scheduler":
+                raise ValueError(f"Registration kind '{job.kind}' is not a scheduler job.")
+            return job
+
+        matches = [reg for reg in self._scheduler_jobs if _callable_matches(reg.hook, job)]
+        if not matches:
+            raise ValueError(f"No scheduler job matches {job!r}.")
+        if len(matches) > 1:
+            raise ValueError(f"Multiple scheduler jobs match {job!r}.")
+        return matches[0]
+
+    def run_scheduler_job(self, job: Union[int, str, Callable, HookRegistration]):
+        reg = self.get_scheduler_job(job)
+        return reg.hook(*reg.extra["func_args"], **reg.extra["func_kwargs"])
+
+    def _record(self, registration: HookRegistration) -> None:
+        self.registrations.append(registration)
+
+    def _validate_entity(self, entity: str) -> None:
+        if entity not in self.model_spec.entities:
+            raise ValueError(f"Entity '{entity}' does not exist.")
+
+    def _hook_type_for_attr(self, entity: str, attr: str) -> str:
+        try:
+            return self.attr_spec_t_to_on_attr[self.model_spec.attributes[entity, attr].t]
+        except KeyError as e:
+            raise ValueError(
+                f"Cannot register hook for attribute {entity}/{attr}, are you sure it exists?"
+            ) from e
+
+    @staticmethod
+    def _assert_record_eid(record: dict) -> Any:
+        if "eid" not in record:
+            raise ValueError("Correlation hook records must contain an 'eid' field.")
+        return record["eid"]
+
+    def _run_no_arg_hooks(self, hooks: list[HookRegistration]) -> list[DataPointTask]:
+        tasks: list[DataPointTask] = []
+        with task_context(self.model_spec):
+            for reg in hooks:
+                hook_tasks = reg.hook()
+                if isinstance(hook_tasks, list):
+                    tasks.extend(hook_tasks)
+        return tasks
+
+    @staticmethod
+    def _matching_update_hooks(hooks: dict, entity_type: str, hook_id: Optional[str]):
+        if hook_id is not None:
+            return list(hooks[entity_type, hook_id])
+        return [reg for (etype, _), regs in hooks.items() if etype == entity_type for reg in regs]
+
+
+def _callable_matches(func: Callable, expected: Union[str, Callable]) -> bool:
+    if callable(expected):
+        return func == expected
+    func_name = get_func_name(func)
+    return func_name == expected or func_name.endswith(f".{expected}")
diff --git a/mkdocs.yml b/mkdocs.yml
index 44ae7b21..32695963 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -49,6 +49,7 @@ nav:
       - Add an attribute: howto/add-attribute.md
       - Add an input module: howto/add-input.md
       - Add a secondary module: howto/add-module.md
+      - Test a secondary module: howto/test-module.md
       - Deploy an app: howto/deploy-app.md
       - Develop DP3 itself: howto/develop-dp3.md
       - Extend Docs: howto/extending.md
diff --git a/tests/test_common/test_module_testing.py b/tests/test_common/test_module_testing.py
new file mode 100644
index 00000000..9f7c57c0
--- /dev/null
+++ b/tests/test_common/test_module_testing.py
@@ -0,0 +1,228 @@
+import copy
+import os
+import warnings
+from datetime import datetime, timedelta
+
+from pydantic import ValidationError
+
+from dp3.common.base_module import BaseModule
+from dp3.common.task import DataPointTask, task_context
+from dp3.common.types import UTC
+from dp3.testing import DP3ModuleTestCase
+
+
+class SampleModule(BaseModule):
+    def __init__(self, platform_config, module_config, registrar):
+        super().__init__(platform_config, module_config, registrar)
+        registrar.register_allow_entity_creation_hook(self.allow_create, "test_entity_type")
+        registrar.register_on_entity_creation_hook(self.on_create, "test_entity_type")
+        registrar.register_on_new_attr_hook(self.on_string, "test_entity_type", "test_attr_string")
+        registrar.register_correlation_hook(
+            self.copy_string,
+            "test_entity_type",
+            depends_on=[["test_attr_string"]],
+            may_change=[["test_attr_int"]],
+        )
+        registrar.register_correlation_hook_with_master_record(
+            self.copy_master_float,
+            "test_entity_type",
+            depends_on=[],
+            may_change=[["test_attr_float"]],
+        )
+        registrar.register_periodic_update_hook(
+            self.periodic_record, "sample", "test_entity_type", "1d"
+        )
+        registrar.scheduler_register(self.reload, minute="*/5")
+
+    def load_config(self, _config, module_config):
+        self.prefix = module_config.get("prefix", "created")
+        self.reload_count = 0
+
+    def allow_create(self, eid, _task):
+        return eid != "deny"
+
+    def on_create(self, eid, _task):
+        return [
+            DataPointTask(
+                etype="test_entity_type",
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": "test_entity_type",
+                        "eid": eid,
+                        "attr": "test_attr_string",
+                        "src": "secondary/sample",
+                        "v": f"{self.prefix}:{eid}",
+                    }
+                ],
+            )
+        ]
+
+    def on_string(self, eid, dp):
+        return [
+            DataPointTask(
+                etype="test_entity_type",
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": "test_entity_type",
+                        "eid": eid,
+                        "attr": "test_attr_int",
+                        "src": "secondary/sample",
+                        "v": len(dp.v),
+                    }
+                ],
+            )
+        ]
+
+    def copy_string(self, _entity_type, record):
+        record["test_attr_int"] = len(record["test_attr_string"])
+
+    def copy_master_float(self, _entity_type, record, master_record):
+        record["test_attr_float"] = master_record["test_attr_float"]["v"]
+
+    def periodic_record(self, entity_type, eid, master_record):
+        return [
+            DataPointTask(
+                etype=entity_type,
+                eid=eid,
+                data_points=[
+                    {
+                        "etype": entity_type,
+                        "eid": eid,
+                        "attr": "test_attr_string",
+                        "src": "secondary/sample",
+                        "v": master_record["test_attr_string"]["v"],
+                    }
+                ],
+            )
+        ]
+
+    def reload(self):
+        self.reload_count += 1
+
+
+class DeprecatedHookModule(BaseModule):
+    def __init__(self, platform_config, module_config, registrar):
+        super().__init__(platform_config, module_config, registrar)
+        registrar.register_entity_hook("on_entity_creation", self.on_create, "test_entity_type")
+        registrar.register_attr_hook(
+            "on_new_plain", self.on_string, "test_entity_type", "test_attr_string"
+        )
+
+    def on_create(self, _eid, _task):
+        return []
+
+    def on_string(self, _eid, _dp):
+        return []
+
+
+TEST_CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "test_config")
+
+
+class TestDP3ModuleTestCase(DP3ModuleTestCase):
+    config_dir = TEST_CONFIG_DIR
+    module_class = SampleModule
+    module_config = {"prefix": "hello"}
+
+    def test_entity_creation_hook_and_partial_datapoint_assertion(self):
+        self.assertTrue(self.run_allow_entity_creation("test_entity_type", "ok"))
+        self.assertFalse(self.run_allow_entity_creation("test_entity_type", "deny"))
+
+        tasks = self.run_on_entity_creation("test_entity_type", "ok")
+
+        self.assertDatapoint(
+            tasks,
+            etype="test_entity_type",
+            eid="ok",
+            attr="test_attr_string",
+            v="hello:ok",
+        )
+
+    def test_attribute_hook(self):
+        dp = self.make_plain_datapoint("test_entity_type", "e1", "test_attr_string", "abcd")
+
+        tasks = self.run_on_new_attr("test_entity_type", "test_attr_string", "e1", dp)
+
+        self.assertDatapoint(tasks, attr="test_attr_int", v=4)
+
+    def test_correlation_and_master_record_hooks(self):
+        record = {"eid": "e1", "test_attr_string": "abcdef"}
+        master_record = {"test_attr_float": {"v": 1.5}}
+
+        tasks = self.run_correlation_hooks("test_entity_type", record, master_record)
+
+        self.assertNoTasks(tasks)
+        self.assertRecordContains(record, test_attr_int=6, test_attr_float=1.5)
+        self.assertRecordAttr(record, "test_attr_int", 6)
+
+    def test_record_unchanged_assertion(self):
+        record = {"eid": "e1", "labels": ["a", "b"]}
+        unchanged = copy.deepcopy(record)
+
+        self.assertRecordUnchanged(record, unchanged)
+
+    def test_empty_datapoint_task_is_rejected(self):
+        with self.assertRaises(ValidationError):
+            self.make_task("test_entity_type", "e1")
+
+    def test_empty_datapoint_task_is_allowed_for_internal_synthetic_context(self):
+        with task_context(self.model_spec, allow_empty_data_point_task=True):
+            task = DataPointTask(etype="test_entity_type", eid="e1")
+
+        self.assertEqual("test_entity_type", task.etype)
+        self.assertEqual("e1", task.eid)
+        self.assertEqual([], task.data_points)
+
+    def test_no_datapoints_assertion_accepts_tasks_without_datapoints(self):
+        tasks = [self.make_task("test_entity_type", "e1", delete=True)]
+
+        self.assertNoDatapoints(tasks)
+
+    def test_timeseries_datapoint_helper_infers_t2_for_regular_timeseries(self):
+        t1 = datetime(2024, 1, 1, tzinfo=UTC)
+
+        dp = self.make_timeseries_datapoint(
+            "test_entity_type",
+            "e1",
+            "test_attr_timeseries",
+            {"value": [1, 2, 3]},
+            t1=t1,
+        )
+
+        self.assertEqual(t1, dp.t1)
+        self.assertEqual(t1 + timedelta(minutes=30), dp.t2)
+        self.assertEqual([1, 2, 3], dp.v.value)
+
+    def test_periodic_update_hook(self):
+        tasks = self.run_periodic_update(
+            "test_entity_type",
+            "e1",
+            {"test_attr_string": {"v": "periodic"}},
+            hook_id="sample",
+        )
+
+        self.assertDatapoint(tasks, attr="test_attr_string", v="periodic")
+
+    def test_registration_assertions(self):
+        self.assert_registered("on_new_attr", entity="test_entity_type", attr="test_attr_string")
+        self.assert_registered_once("correlation", entity_type="test_entity_type")
+        self.assert_registered_attrs("test_entity_type", ["test_attr_string"])
+
+    def test_scheduler_helpers(self):
+        self.assert_scheduler_registered(func="reload", minute="*/5")
+
+        self.run_scheduler_job("reload")
+
+        self.assertEqual(1, self.module.reload_count)
+
+    def test_deprecated_registrar_methods_are_supported_with_warnings(self):
+        registrar = self.make_registrar()
+        with warnings.catch_warnings(record=True) as captured:
+            warnings.simplefilter("always", DeprecationWarning)
+            self.make_module(DeprecatedHookModule, {}, registrar)
+
+        self.assertEqual(2, len(captured))
+        self.assertTrue(all(item.category is DeprecationWarning for item in captured))
+        self.assertEqual(2, len(registrar.registrations))
+        self.assertTrue(all(reg.deprecated for reg in registrar.registrations))
diff --git a/tests/test_config/db_entities/test_entity_type.yml b/tests/test_config/db_entities/test_entity_type.yml
index 0caeef40..7ca7f98e 100644
--- a/tests/test_config/db_entities/test_entity_type.yml
+++ b/tests/test_config/db_entities/test_entity_type.yml
@@ -81,6 +81,16 @@ attribs:
       aggregate: true
       post_validity: 1h
       pre_validity: 1h
+  test_attr_timeseries:
+    name: test_attr_timeseries
+    type: timeseries
+    timeseries_type: regular
+    timeseries_params:
+      max_age: 7d
+      time_step: 10m
+    series:
+      value:
+        data_type: int
   test_attr_probability:
     name: test_attr_probability
     type: plain