diff --git a/ci/lint.sh b/ci/lint.sh index 34ef27b..c405891 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,4 +1,4 @@ #!/usr/bin/env sh . .venv/bin/activate -pylint src/ tests/unit/ +pylint src/ tests/unit/ examples/ diff --git a/ci/typecheck.sh b/ci/typecheck.sh index 4fa4ffe..706035c 100755 --- a/ci/typecheck.sh +++ b/ci/typecheck.sh @@ -1,4 +1,4 @@ #!/usr/bin/env sh . .venv/bin/activate -mypy --config-file mypy.ini src/ ./tests/unit/ +mypy --config-file mypy.ini src/ ./tests/unit/ examples/ diff --git a/dev-requirements.txt b/dev-requirements.txt index 70c1bc9..183afa3 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,19 +8,19 @@ alabaster==0.7.13 # via sphinx annotated-types==0.7.0 # via pydantic -argcomplete==3.4.0 +argcomplete==3.5.0 # via datamodel-code-generator astroid==3.2.4 # via pylint -babel==2.15.0 +babel==2.16.0 # via sphinx -black==24.4.2 +black==24.8.0 # via datamodel-code-generator build==1.2.1 # via pip-tools -cachetools==5.4.0 +cachetools==5.5.0 # via tox -certifi==2024.7.4 +certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit @@ -35,11 +35,9 @@ click==8.1.7 # s2-python (setup.cfg) colorama==0.4.6 # via tox -coverage[toml]==7.6.0 - # via - # coverage - # pytest-cov -datamodel-code-generator==0.25.8 +coverage[toml]==7.6.1 + # via pytest-cov +datamodel-code-generator==0.26.0 # via s2-python (setup.cfg) dill==0.3.8 # via pylint @@ -64,13 +62,13 @@ genson==1.3.0 # via datamodel-code-generator identify==2.6.0 # via pre-commit -idna==3.7 +idna==3.8 # via # email-validator # requests imagesize==1.4.1 # via sphinx -importlib-metadata==8.2.0 +importlib-metadata==8.4.0 # via # build # sphinx @@ -90,7 +88,7 @@ markupsafe==2.1.5 # via jinja2 mccabe==0.7.0 # via pylint -mypy==1.11.0 +mypy==1.11.2 # via s2-python (setup.cfg) mypy-extensions==1.0.0 # via @@ -133,7 +131,7 @@ pygments==2.18.0 # via # sphinx # sphinx-tabs -pylint==3.2.6 +pylint==3.2.7 # via s2-python (setup.cfg) pyproject-api==1.7.1 # via tox @@ -158,7 +156,7 @@ pytz==2024.1 # via # babel # s2-python (setup.cfg) -pyyaml==6.0.1 +pyyaml==6.0.2 # via # datamodel-code-generator # pre-commit @@ -214,9 +212,9 @@ tomli==2.0.1 # pyproject-api # pytest # tox -tomlkit==0.13.0 +tomlkit==0.13.2 # via pylint -tox==4.16.0 +tox==4.18.0 # via s2-python (setup.cfg) types-pytz==2024.1.0.20240417 # via s2-python (setup.cfg) @@ -235,9 +233,11 @@ virtualenv==20.26.3 # via # pre-commit # tox -wheel==0.43.0 +websockets==13.0.1 + # via s2-python (setup.cfg) +wheel==0.44.0 # via pip-tools -zipp==3.19.2 +zipp==3.20.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/examples/example_frbc_rm.py b/examples/example_frbc_rm.py new file mode 100644 index 0000000..bb05bc8 --- /dev/null +++ b/examples/example_frbc_rm.py @@ -0,0 +1,172 @@ +import logging +import sys +import uuid +import signal +import datetime +from typing import Callable + +from s2python.common import ( + EnergyManagementRole, + Duration, + Role, + RoleType, + Commodity, + Currency, + NumberRange, + PowerRange, + CommodityQuantity, +) +from s2python.frbc import ( + FRBCInstruction, + FRBCSystemDescription, + FRBCActuatorDescription, + FRBCStorageDescription, + FRBCOperationMode, + FRBCOperationModeElement, + FRBCFillLevelTargetProfile, + FRBCFillLevelTargetProfileElement, + FRBCStorageStatus, + FRBCActuatorStatus, +) +from s2python.s2_connection import S2Connection, AssetDetails +from s2python.s2_control_type import FRBCControlType, NoControlControlType +from s2python.validate_values_mixin import S2Message + +logger = logging.getLogger("s2python") +logger.addHandler(logging.StreamHandler(sys.stdout)) +logger.setLevel(logging.DEBUG) + + +class MyFRBCControlType(FRBCControlType): + def handle_instruction( + self, conn: S2Connection, msg: S2Message, send_okay: Callable[[], None] + ) -> None: + if not isinstance(msg, FRBCInstruction): + raise RuntimeError( + f"Expected an FRBCInstruction but received a message of type {type(msg)}." + ) + print(f"I have received the message {msg} from {conn}") + + def activate(self, conn: S2Connection) -> None: + print("The control type FRBC is now activated.") + + print("Time to send a FRBC SystemDescription") + actuator_id = uuid.uuid4() + operation_mode_id = uuid.uuid4() + conn.send_msg_and_await_reception_status_sync( + FRBCSystemDescription( + message_id=uuid.uuid4(), + valid_from=datetime.datetime.now(tz=datetime.timezone.utc), + actuators=[ + FRBCActuatorDescription( + id=actuator_id, + operation_modes=[ + FRBCOperationMode( + id=operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_rate=NumberRange( + start_of_range=-5.0, end_of_range=5.0 + ), + power_ranges=[ + PowerRange( + start_of_range=-200.0, + end_of_range=200.0, + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="Load & unload battery", + abnormal_condition_only=False, + ) + ], + transitions=[], + timers=[], + supported_commodities=[Commodity.ELECTRICITY], + ) + ], + storage=FRBCStorageDescription( + fill_level_range=NumberRange(start_of_range=0.0, end_of_range=100.0), + fill_level_label="%", + diagnostic_label="Imaginary battery", + provides_fill_level_target_profile=True, + provides_leakage_behaviour=False, + provides_usage_forecast=False, + ), + ) + ) + print("Also send the target profile") + + conn.send_msg_and_await_reception_status_sync( + FRBCFillLevelTargetProfile( + message_id=uuid.uuid4(), + start_time=datetime.datetime.now(tz=datetime.timezone.utc), + elements=[ + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(30_000), + fill_level_range=NumberRange(start_of_range=20.0, end_of_range=30.0), + ), + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(300_000), + fill_level_range=NumberRange(start_of_range=40.0, end_of_range=50.0), + ), + ], + ) + ) + + print("Also send the storage status.") + conn.send_msg_and_await_reception_status_sync( + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + ) + + print("Also send the actuator status.") + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.5, + ) + ) + + def deactivate(self, conn: S2Connection) -> None: + print("The control type FRBC is now deactivated.") + + +class MyNoControlControlType(NoControlControlType): + def activate(self, conn: S2Connection) -> None: + print("The control type NoControl is now activated.") + + def deactivate(self, conn: S2Connection) -> None: + print("The control type NoControl is now deactivated.") + + +s2_conn = S2Connection( + url="ws://localhost:8001/backend/rm/s2python-frbc/cem/dummy_model/ws", + role=EnergyManagementRole.RM, + control_types=[MyFRBCControlType(), MyNoControlControlType()], + asset_details=AssetDetails( + resource_id=str(uuid.uuid4()), + name="Some asset", + instruction_processing_delay=Duration.from_milliseconds(20), + roles=[Role(role=RoleType.ENERGY_CONSUMER, commodity=Commodity.ELECTRICITY)], + currency=Currency.EUR, + provides_forecast=False, + provides_power_measurements=[CommodityQuantity.ELECTRIC_POWER_L1], + ), +) + + +def stop(signal_num, _current_stack_frame): + print(f"Received signal {signal_num}. Will stop S2 connection.") + s2_conn.stop() + + +signal.signal(signal.SIGINT, stop) +signal.signal(signal.SIGTERM, stop) + +s2_conn.start_as_rm() diff --git a/setup.cfg b/setup.cfg index 453da1b..d2d9ca2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ install_requires = pydantic~=2.8.2 pytz click + websockets~=13.0.1 [options.packages.find] where = src diff --git a/src/s2python/common/__init__.py b/src/s2python/common/__init__.py index 6bc46f5..806de7e 100644 --- a/src/s2python/common/__init__.py +++ b/src/s2python/common/__init__.py @@ -1,5 +1,6 @@ from s2python.generated.gen_s2 import ( RoleType, + Currency, CommodityQuantity, Commodity, InstructionStatus, @@ -7,7 +8,6 @@ EnergyManagementRole, SessionRequestType, ControlType, - Currency, RevokableObjects, ) diff --git a/src/s2python/frbc/rm.py b/src/s2python/frbc/rm.py new file mode 100644 index 0000000..e69de29 diff --git a/src/s2python/reception_status_awaiter.py b/src/s2python/reception_status_awaiter.py new file mode 100644 index 0000000..5c4bd42 --- /dev/null +++ b/src/s2python/reception_status_awaiter.py @@ -0,0 +1,60 @@ +"""ReceptationStatusAwaiter class which notifies any coroutine waiting for a certain reception status message. + +Copied from +https://github.com/flexiblepower/s2-analyzer/blob/main/backend/s2_analyzer_backend/reception_status_awaiter.py under +Apache2 license on 31-08-2024. +""" + +import asyncio +import uuid +from typing import Dict + +from s2python.common import ReceptionStatus + + +class ReceptionStatusAwaiter: + received: Dict[uuid.UUID, ReceptionStatus] + awaiting: Dict[uuid.UUID, asyncio.Event] + + def __init__(self) -> None: + self.received = {} + self.awaiting = {} + + async def wait_for_reception_status( + self, message_id: uuid.UUID, timeout_reception_status: float + ) -> ReceptionStatus: + if message_id in self.received: + reception_status = self.received[message_id] + else: + if message_id in self.awaiting: + received_event = self.awaiting[message_id] + else: + received_event = asyncio.Event() + self.awaiting[message_id] = received_event + + await asyncio.wait_for(received_event.wait(), timeout_reception_status) + reception_status = self.received[message_id] + + if message_id in self.awaiting: + del self.awaiting[message_id] + + return reception_status + + async def receive_reception_status(self, reception_status: ReceptionStatus) -> None: + if not isinstance(reception_status, ReceptionStatus): + raise RuntimeError( + f"Expected a ReceptionStatus but received message {reception_status}" + ) + + if reception_status.subject_message_id in self.received: + raise RuntimeError( + f"ReceptationStatus for message_subject_id {reception_status.subject_message_id} has already " + f"been received!" + ) + + self.received[reception_status.subject_message_id] = reception_status + awaiting = self.awaiting.get(reception_status.subject_message_id) + + if awaiting: + awaiting.set() + del self.awaiting[reception_status.subject_message_id] diff --git a/src/s2python/s2_connection.py b/src/s2python/s2_connection.py new file mode 100644 index 0000000..28ac6da --- /dev/null +++ b/src/s2python/s2_connection.py @@ -0,0 +1,470 @@ +import asyncio +import json +import logging +import threading +import uuid +from dataclasses import dataclass +from typing import Optional, List, Type, Dict, Callable, Awaitable, Union + +from websockets.asyncio.client import ClientConnection as WSConnection, connect as ws_connect + +from s2python.common import ( + ReceptionStatusValues, + ReceptionStatus, + Handshake, + EnergyManagementRole, + Role, + HandshakeResponse, + ResourceManagerDetails, + Duration, + Currency, + SelectControlType, +) +from s2python.generated.gen_s2 import CommodityQuantity +from s2python.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.s2_control_type import S2ControlType +from s2python.s2_parser import S2Parser +from s2python.s2_validation_error import S2ValidationError +from s2python.validate_values_mixin import S2Message +from s2python.version import S2_VERSION + +logger = logging.getLogger("s2python") + + +@dataclass +class AssetDetails: # pylint: disable=too-many-instance-attributes + resource_id: str + + provides_forecast: bool + provides_power_measurements: List[CommodityQuantity] + + instruction_processing_delay: Duration + roles: List[Role] + currency: Optional[Currency] = None + + name: Optional[str] = None + manufacturer: Optional[str] = None + model: Optional[str] = None + firmware_version: Optional[str] = None + serial_number: Optional[str] = None + + def to_resource_manager_details( + self, control_types: List[S2ControlType] + ) -> ResourceManagerDetails: + return ResourceManagerDetails( + available_control_types=[ + control_type.get_protocol_control_type() for control_type in control_types + ], + currency=self.currency, + firmware_version=self.firmware_version, + instruction_processing_delay=self.instruction_processing_delay, + manufacturer=self.manufacturer, + message_id=uuid.uuid4(), + model=self.model, + name=self.name, + provides_forecast=self.provides_forecast, + provides_power_measurement_types=self.provides_power_measurements, + resource_id=self.resource_id, + roles=self.roles, + serial_number=self.serial_number, + ) + + +S2MessageHandler = Union[ + Callable[["S2Connection", S2Message, Callable[[], None]], None], + Callable[["S2Connection", S2Message, Awaitable[None]], Awaitable[None]], +] + + +class SendOkay: + status_is_send: threading.Event + connection: "S2Connection" + subject_message_id: uuid.UUID + + def __init__(self, connection: "S2Connection", subject_message_id: uuid.UUID): + self.status_is_send = threading.Event() + self.connection = connection + self.subject_message_id = subject_message_id + + async def run_async(self) -> None: + self.status_is_send.set() + + await self.connection.respond_with_reception_status( + subject_message_id=str(self.subject_message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + def run_sync(self) -> None: + self.status_is_send.set() + + self.connection.respond_with_reception_status_sync( + subject_message_id=str(self.subject_message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + async def ensure_send_async(self, type_msg: Type[S2Message]) -> None: + if not self.status_is_send.is_set(): + logger.warning( + "Handler for message %s %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type_msg, + self.subject_message_id, + ) + await self.run_async() + + def ensure_send_sync(self, type_msg: Type[S2Message]) -> None: + if not self.status_is_send.is_set(): + logger.warning( + "Handler for message %s %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type_msg, + self.subject_message_id, + ) + self.run_sync() + + +class MessageHandlers: + handlers: Dict[Type[S2Message], S2MessageHandler] + + def __init__(self) -> None: + self.handlers = {} + + async def handle_message(self, connection: "S2Connection", msg: S2Message) -> None: + """Handle the S2 message using the registered handler. + + :param connection: The S2 conncetion the `msg` is received from. + :param msg: The S2 message + """ + handler = self.handlers.get(type(msg)) + if handler is not None: + send_okay = SendOkay(connection, msg.message_id) # type: ignore[attr-defined] + + try: + if asyncio.iscoroutinefunction(handler): + await handler(connection, msg, send_okay.run_async()) # type: ignore[arg-type] + await send_okay.ensure_send_async(type(msg)) + else: + + def do_message() -> None: + handler(connection, msg, send_okay.run_sync) # type: ignore[arg-type] + send_okay.ensure_send_sync(type(msg)) + + eventloop = asyncio.get_event_loop() + await eventloop.run_in_executor(executor=None, func=do_message) + except Exception: + if not send_okay.status_is_send.is_set(): + await connection.respond_with_reception_status( + subject_message_id=str(msg.message_id), # type: ignore[attr-defined] + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label=f"While processing message {msg.message_id} " # type: ignore[attr-defined] + f"an unrecoverable error occurred.", + ) + raise + else: + logger.warning( + "Received a message of type %s but no handler is registered. Ignoring the message.", + type(msg), + ) + + def register_handler(self, msg_type: Type[S2Message], handler: S2MessageHandler) -> None: + """Register a coroutine function or a normal function as the handler for a specific S2 message type. + + :param msg_type: The S2 message type to attach the handler to. + :param handler: The function (asynchronuous or normal) which should handle the S2 message. + """ + self.handlers[msg_type] = handler + + +class S2Connection: # pylint: disable=too-many-instance-attributes + url: str + reception_status_awaiter: ReceptionStatusAwaiter + ws: Optional[WSConnection] + s2_parser: S2Parser + control_types: List[S2ControlType] + role: EnergyManagementRole + asset_details: AssetDetails + + _thread: threading.Thread + + _handlers: MessageHandlers + _current_control_type: Optional[S2ControlType] + _received_messages: asyncio.Queue + + _eventloop: asyncio.AbstractEventLoop + _background_tasks: Optional[asyncio.Task] + _stop_event: asyncio.Event + + def __init__( + self, + url: str, + role: EnergyManagementRole, + control_types: List[S2ControlType], + asset_details: AssetDetails, + ) -> None: + self.url = url + self.reception_status_awaiter = ReceptionStatusAwaiter() + self.s2_parser = S2Parser() + + self._handlers = MessageHandlers() + self._current_control_type = None + + self._eventloop = asyncio.new_event_loop() + self._background_tasks = None + + self.control_types = control_types + self.role = role + self.asset_details = asset_details + + self._handlers.register_handler(SelectControlType, self.handle_select_control_type_as_rm) + self._handlers.register_handler(Handshake, self.handle_handshake) + self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response_as_rm) + + def start_as_rm(self) -> None: + self._thread = threading.Thread(target=self._run_eventloop) + self._thread.start() + logger.debug("Started eventloop thread!") + + def _run_eventloop(self) -> None: + logger.debug("Starting eventloop") + self._eventloop.run_until_complete(self._run_as_rm()) + + def stop(self) -> None: + """Stops the S2 connection. + + Note: Ensure this method is called from a different thread than the thread running the S2 connection. + Otherwise it will block waiting on the coroutine _do_stop to terminate successfully but it can't run + the coroutine. A `RuntimeError` will be raised to prevent the indefinite block. + """ + if threading.current_thread() == self._thread: + raise RuntimeError( + "Do not call stop from the thread running the S2 connection. This results in an " + "infinite block!" + ) + + asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop).result() + + async def _do_stop(self) -> None: + logger.info("Will stop the S2 connection.") + if self._background_tasks: + self._background_tasks.cancel() + self._background_tasks = None + + if self.ws: + await self.ws.close() + await self.ws.wait_closed() + + async def _run_as_rm(self) -> None: + logger.debug("Connecting as S2 resource manager.") + self._received_messages = asyncio.Queue() + await self.connect_ws() + + self._background_tasks = self._eventloop.create_task( + asyncio.wait( + (self._receive_messages(), self._handle_received_messages()), + return_when=asyncio.FIRST_EXCEPTION, + ) + ) + + await self.connect_as_rm() + done: List[asyncio.Task] + pending: List[asyncio.Task] + (done, pending) = await self._background_tasks + + for task in done: + task.result() + + for task in pending: + task.cancel() + + async def connect_ws(self) -> None: + self.ws = await ws_connect(uri=self.url) + + async def connect_as_rm(self) -> None: + await self.send_msg_and_await_reception_status_async( + Handshake( + message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[S2_VERSION] + ) + ) + logger.debug("Send handshake to CEM. Expecting Handshake and HandshakeResponse from CEM.") + + async def handle_handshake( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, Handshake): + logger.error( + "Handler for Handshake received a message of the wrong type: %s", type(message) + ) + return + + logger.debug( + "%s supports S2 protocol versions: %s", + message.role, + message.supported_protocol_versions, + ) + await send_okay + + async def handle_handshake_response_as_rm( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, HandshakeResponse): + logger.error( + "Handler for HandshakeResponse received a message of the wrong type: %s", + type(message), + ) + return + + logger.debug("Received HandshakeResponse %s", message.to_json()) + + logger.debug("CEM selected to use version %s", message.selected_protocol_version) + await send_okay + logger.debug("Handshake complete. Sending first ResourceManagerDetails.") + + await self.send_msg_and_await_reception_status_async( + self.asset_details.to_resource_manager_details(self.control_types) + ) + + async def handle_select_control_type_as_rm( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, SelectControlType): + logger.error( + "Handler for SelectControlType received a message of the wrong type: %s", + type(message), + ) + return + + await send_okay + + logger.debug("CEM selected control type %s. Activating control type.", message.control_type) + + control_types_by_protocol_name = { + c.get_protocol_control_type(): c for c in self.control_types + } + selected_control_type: Optional[S2ControlType] = control_types_by_protocol_name.get( + message.control_type + ) + + if self._current_control_type is not None: + await self._eventloop.run_in_executor(None, self._current_control_type.deactivate, self) + + self._current_control_type = selected_control_type + + if self._current_control_type is not None: + await self._eventloop.run_in_executor(None, self._current_control_type.activate, self) + self._current_control_type.register_handlers(self._handlers) + + async def _receive_messages(self) -> None: + """Receives all incoming messages in the form of a generator. + + Will also receive the ReceptionStatus messages but instead of yielding these messages, they are routed + to any calls of `send_msg_and_await_reception_status`. + """ + if self.ws is None: + raise RuntimeError( + "Cannot receive messages if websocket connection is not yet established." + ) + + logger.info("S2 connection has started to receive messages.") + + async for message in self.ws: + try: + s2_msg: S2Message = self.s2_parser.parse_as_any_message(message) + except json.JSONDecodeError: + await self._send_and_forget( + ReceptionStatus( + subject_message_id="00000000-0000-0000-0000-000000000000", + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + ) + ) + except S2ValidationError as e: + json_msg = json.loads(message) + message_id = json_msg.get("message_id") + if message_id: + await self.respond_with_reception_status( + subject_message_id=message_id, + status=ReceptionStatusValues.INVALID_MESSAGE, + diagnostic_label=str(e), + ) + else: + await self.respond_with_reception_status( + subject_message_id="00000000-0000-0000-0000-000000000000", + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Message appears valid json but could not find a message_id field.", + ) + else: + logger.debug("Received message %s", s2_msg.to_json()) + + if isinstance(s2_msg, ReceptionStatus): + logger.debug( + "Message is a reception status for %s so registering in cache.", + s2_msg.subject_message_id, + ) + await self.reception_status_awaiter.receive_reception_status(s2_msg) + else: + await self._received_messages.put(s2_msg) + + async def _send_and_forget(self, s2_msg: S2Message) -> None: + if self.ws is None: + raise RuntimeError( + "Cannot send messages if websocket connection is not yet established." + ) + + json_msg = s2_msg.to_json() + logger.debug("Sending message %s", json_msg) + await self.ws.send(json_msg) + + async def respond_with_reception_status( + self, subject_message_id: str, status: ReceptionStatusValues, diagnostic_label: str + ) -> None: + logger.debug("Responding to message %s with status %s", subject_message_id, status) + await self._send_and_forget( + ReceptionStatus( + subject_message_id=subject_message_id, + status=status, + diagnostic_label=diagnostic_label, + ) + ) + + def respond_with_reception_status_sync( + self, subject_message_id: str, status: ReceptionStatusValues, diagnostic_label: str + ) -> None: + asyncio.run_coroutine_threadsafe( + self.respond_with_reception_status(subject_message_id, status, diagnostic_label), + self._eventloop, + ).result() + + async def send_msg_and_await_reception_status_async( + self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True + ) -> ReceptionStatus: + await self._send_and_forget(s2_msg) + logger.debug( + "Waiting for ReceptionStatus for %s %s seconds", + s2_msg.message_id, # type: ignore[attr-defined] + timeout_reception_status, + ) + reception_status = await self.reception_status_awaiter.wait_for_reception_status( + s2_msg.message_id, timeout_reception_status # type: ignore[attr-defined] + ) + + if reception_status.status != ReceptionStatusValues.OK and raise_on_error: + raise RuntimeError(f"ReceptionStatus was not OK but rather {reception_status.status}") + + return reception_status + + def send_msg_and_await_reception_status_sync( + self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True + ) -> ReceptionStatus: + return asyncio.run_coroutine_threadsafe( + self.send_msg_and_await_reception_status_async( + s2_msg, timeout_reception_status, raise_on_error + ), + self._eventloop, + ).result() + + async def _handle_received_messages(self) -> None: + while True: + msg = await self._received_messages.get() + await self._handlers.handle_message(self, msg) diff --git a/src/s2python/s2_control_type.py b/src/s2python/s2_control_type.py new file mode 100644 index 0000000..f9a4545 --- /dev/null +++ b/src/s2python/s2_control_type.py @@ -0,0 +1,56 @@ +import abc +import typing + +from s2python.common import ControlType as ProtocolControlType +from s2python.frbc import FRBCInstruction +from s2python.validate_values_mixin import S2Message + +if typing.TYPE_CHECKING: + from s2python.s2_connection import S2Connection, MessageHandlers + + +class S2ControlType(abc.ABC): + @abc.abstractmethod + def get_protocol_control_type(self) -> ProtocolControlType: ... + + @abc.abstractmethod + def register_handlers(self, handlers: "MessageHandlers") -> None: ... + + @abc.abstractmethod + def activate(self, conn: "S2Connection") -> None: ... + + @abc.abstractmethod + def deactivate(self, conn: "S2Connection") -> None: ... + + +class FRBCControlType(S2ControlType): + def get_protocol_control_type(self) -> ProtocolControlType: + return ProtocolControlType.FILL_RATE_BASED_CONTROL + + def register_handlers(self, handlers: "MessageHandlers") -> None: + handlers.register_handler(FRBCInstruction, self.handle_instruction) + + @abc.abstractmethod + def handle_instruction( + self, conn: "S2Connection", msg: S2Message, send_okay: typing.Callable[[], None] + ) -> None: ... + + @abc.abstractmethod + def activate(self, conn: "S2Connection") -> None: ... + + @abc.abstractmethod + def deactivate(self, conn: "S2Connection") -> None: ... + + +class NoControlControlType(S2ControlType): + def get_protocol_control_type(self) -> ProtocolControlType: + return ProtocolControlType.NOT_CONTROLABLE + + def register_handlers(self, handlers: "MessageHandlers") -> None: + pass + + @abc.abstractmethod + def activate(self, conn: "S2Connection") -> None: ... + + @abc.abstractmethod + def deactivate(self, conn: "S2Connection") -> None: ... diff --git a/src/s2python/s2_parser.py b/src/s2python/s2_parser.py index 99433b6..906a286 100644 --- a/src/s2python/s2_parser.py +++ b/src/s2python/s2_parser.py @@ -24,18 +24,18 @@ FRBCTimerStatus, FRBCUsageForecast, ) -from s2python.validate_values_mixin import SupportsValidation +from s2python.validate_values_mixin import S2Message from s2python.s2_validation_error import S2ValidationError LOGGER = logging.getLogger(__name__) S2MessageType = str -M = TypeVar("M", bound=SupportsValidation) +M = TypeVar("M", bound=S2Message) # May be generated with development_utilities/generate_s2_message_type_to_class.py -TYPE_TO_MESSAGE_CLASS: Dict[str, Type[SupportsValidation]] = { +TYPE_TO_MESSAGE_CLASS: Dict[str, Type[S2Message]] = { "FRBC.ActuatorStatus": FRBCActuatorStatus, "FRBC.FillLevelTargetProfile": FRBCFillLevelTargetProfile, "FRBC.Instruction": FRBCInstruction, @@ -59,13 +59,13 @@ class S2Parser: @staticmethod - def _parse_json_if_required(unparsed_message: Union[dict, str]) -> dict: - if isinstance(unparsed_message, str): + def _parse_json_if_required(unparsed_message: Union[dict, str, bytes]) -> dict: + if isinstance(unparsed_message, (str, bytes)): return json.loads(unparsed_message) return unparsed_message @staticmethod - def parse_as_any_message(unparsed_message: Union[dict, str]) -> SupportsValidation: + def parse_as_any_message(unparsed_message: Union[dict, str, bytes]) -> S2Message: """Parse the message as any S2 python message regardless of message type. :param unparsed_message: The message as a JSON-formatted string or as a json-parsed dictionary. @@ -77,13 +77,16 @@ def parse_as_any_message(unparsed_message: Union[dict, str]) -> SupportsValidati if message_type not in TYPE_TO_MESSAGE_CLASS: raise S2ValidationError( - None, message_json, f"Unable to parse {message_type} as an S2 message. Type unknown.", None + None, + message_json, + f"Unable to parse {message_type} as an S2 message. Type unknown.", + None, ) return TYPE_TO_MESSAGE_CLASS[message_type].model_validate(message_json) @staticmethod - def parse_as_message(unparsed_message: Union[dict, str], as_message: Type[M]) -> M: + def parse_as_message(unparsed_message: Union[dict, str, bytes], as_message: Type[M]) -> M: """Parse the message to a specific S2 python message. :param unparsed_message: The message as a JSON-formatted string or as a JSON-parsed dictionary. @@ -95,7 +98,7 @@ def parse_as_message(unparsed_message: Union[dict, str], as_message: Type[M]) -> return as_message.from_dict(message_json) @staticmethod - def parse_message_type(unparsed_message: Union[dict, str]) -> Optional[S2MessageType]: + def parse_message_type(unparsed_message: Union[dict, str, bytes]) -> Optional[S2MessageType]: """Parse only the message type from the unparsed message. This is useful to call before `parse_as_message` to retrieve the message type and allows for strictly-typed diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index c59cc7f..7d0d9d6 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -1,25 +1,6 @@ -from typing import ( - TypeVar, - Generic, - Protocol, - Type, - Optional, - Callable, - Any, - Union, - AbstractSet, - Mapping, - List, - Dict, - Literal, -) -from typing_extensions import Self - -from pydantic import ( # pylint: disable=no-name-in-module - BaseModel, - ValidationError, -) -from pydantic.main import IncEx +from typing import TypeVar, Generic, Type, Callable, Any, Union, AbstractSet, Mapping, List, Dict + +from pydantic import BaseModel, ValidationError # pylint: disable=no-name-in-module from pydantic.v1.error_wrappers import display_errors # pylint: disable=no-name-in-module from s2python.s2_validation_error import S2ValidationError @@ -31,80 +12,17 @@ MappingIntStrAny = Mapping[IntStr, Any] -class SupportsValidation(Protocol[B_co]): - # ValidateValuesMixin methods - def to_json(self) -> str: ... - - def to_dict(self) -> Dict: ... - - @classmethod - def from_json(cls, json_str: str) -> B_co: ... - - @classmethod - def from_dict(cls, json_dict: Dict) -> B_co: ... +C = TypeVar("C", bound="BaseModel") - # Pydantic methods - @classmethod - def model_validate_json( - cls, - json_data: Union[str, bytes, bytearray], - *, - strict: Optional[bool] = None, - context: Optional[Any] = None, - ) -> Self: ... - @classmethod - def model_validate( - cls, - obj: Any, - *, - strict: Optional[bool] = None, - from_attributes: Optional[bool] = None, - context: Optional[Any] = None, - ) -> Self: ... - - def model_dump( # pylint: disable=too-many-arguments - self, - *, - mode: Union[Literal["json", "python"], str] = "python", - include: IncEx = None, - exclude: IncEx = None, - context: Optional[Any] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> Dict[str, Any]: ... - - def model_dump_json( # pylint: disable=too-many-arguments - self, - *, - indent: Optional[int] = None, - include: IncEx = None, - exclude: IncEx = None, - context: Optional[Any] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> str: ... - - -C = TypeVar("C", bound="SupportsValidation") - - -class S2Message(Generic[C]): +class S2Message(BaseModel, Generic[C]): def to_json(self: C) -> str: try: return self.model_dump_json(by_alias=True, exclude_none=True) except (ValidationError, TypeError) as e: - raise S2ValidationError(type(self), self, "Pydantic raised a format validation error.", e) from e + raise S2ValidationError( + type(self), self, "Pydantic raised a format validation error.", e + ) from e def to_dict(self: C) -> Dict: return self.model_dump() @@ -141,9 +59,7 @@ def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return inner -def catch_and_convert_exceptions( - input_class: Type[SupportsValidation[B_co]], -) -> Type[SupportsValidation[B_co]]: +def catch_and_convert_exceptions(input_class: Type[S2Message[B_co]]) -> Type[S2Message[B_co]]: input_class.__init__ = convert_to_s2exception(input_class.__init__) # type: ignore[method-assign] input_class.__setattr__ = convert_to_s2exception(input_class.__setattr__) # type: ignore[method-assign] input_class.model_validate_json = convert_to_s2exception( # type: ignore[method-assign] diff --git a/src/s2python/version.py b/src/s2python/version.py index 6c5007c..3789fe8 100644 --- a/src/s2python/version.py +++ b/src/s2python/version.py @@ -1 +1,3 @@ VERSION = "0.2.0" + +S2_VERSION = "0.0.2-beta" diff --git a/tests/unit/reception_status_awaiter_test.py b/tests/unit/reception_status_awaiter_test.py new file mode 100644 index 0000000..167966d --- /dev/null +++ b/tests/unit/reception_status_awaiter_test.py @@ -0,0 +1,164 @@ +"""Tests for ReceptionStatusAwaiter. + +Copied from +https://github.com/flexiblepower/s2-analyzer/blob/main/backend/test/s2_analyzer_backend/reception_status_awaiter_test.py +under Apache2 license on 31-08-2024. +""" + +import asyncio +import datetime +import uuid +from unittest import IsolatedAsyncioTestCase + +from s2python.common import ( + ReceptionStatus, + ReceptionStatusValues, + InstructionStatus, + InstructionStatusUpdate, +) +from s2python.reception_status_awaiter import ReceptionStatusAwaiter + + +class ReceptionStatusAwaiterTest(IsolatedAsyncioTestCase): + async def test__wait_for_reception_status__receive_while_waiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + # Act + wait_task = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) + should_be_waiting_still = not wait_task.done() + await awaiter.receive_reception_status(s2_reception_status) + await wait_task + received_s2_reception_status = wait_task.result() + + # Assert + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + self.assertTrue(should_be_waiting_still) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status) + + async def test__wait_for_reception_status__already_received(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + # Act + await awaiter.receive_reception_status(s2_reception_status) + received_s2_reception_status = await awaiter.wait_for_reception_status(message_id, 1.0) + + # Assert + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status) + + async def test__wait_for_reception_status__multiple_receive_while_waiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + # Act + wait_task_1 = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) + wait_task_2 = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) + should_be_waiting_still_1 = not wait_task_1.done() + should_be_waiting_still_2 = not wait_task_2.done() + await awaiter.receive_reception_status(s2_reception_status) + await wait_task_1 + await wait_task_2 + received_s2_reception_status_1 = wait_task_1.result() + received_s2_reception_status_2 = wait_task_2.result() + + # Assert + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + self.assertTrue(should_be_waiting_still_1) + self.assertTrue(should_be_waiting_still_2) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status_1) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status_2) + + async def test__receive_reception_status__wrong_message(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + s2_msg = InstructionStatusUpdate( + message_id=uuid.uuid4(), + instruction_id=uuid.uuid4(), + status_type=InstructionStatus.NEW, + timestamp=datetime.datetime.now(datetime.timezone.utc), + ) + + # Act / Assert + with self.assertRaises(RuntimeError): + await awaiter.receive_reception_status(s2_msg) # type: ignore[arg-type] + + async def test__receive_reception_status__received_duplicate(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + s2_reception_status = ReceptionStatus( + subject_message_id=uuid.uuid4(), status=ReceptionStatusValues.OK + ) + + # Act / Assert + await awaiter.receive_reception_status(s2_reception_status) + with self.assertRaises(RuntimeError): + await awaiter.receive_reception_status(s2_reception_status) + + async def test__receive_reception_status__receive_no_awaiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + # Act + await awaiter.receive_reception_status(s2_reception_status) + + # Assert + expected_received = { + message_id: ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + } + self.assertEqual(awaiter.received, expected_received) + self.assertEqual(awaiter.awaiting, {}) + + async def test__receive_reception_status__receive_with_awaiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + awaiting_event = asyncio.Event() + message_id = uuid.uuid4() + awaiter.awaiting = {message_id: awaiting_event} + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + + # Act + should_not_be_set = not awaiting_event.is_set() + await awaiter.receive_reception_status(s2_reception_status) + should_be_set = awaiting_event.is_set() + + # Assert + expected_received = { + message_id: ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) + } + + self.assertTrue(should_not_be_set) + self.assertTrue(should_be_set) + self.assertEqual(awaiter.received, expected_received) + self.assertEqual(awaiter.awaiting, {}) diff --git a/tests/unit/s2_connection_test.py b/tests/unit/s2_connection_test.py new file mode 100644 index 0000000..fcb8b37 --- /dev/null +++ b/tests/unit/s2_connection_test.py @@ -0,0 +1,65 @@ +# import unittest +# +# +# class S2ConnectionTest(unittest.TestCase): +# async def test__send_and_await_reception_status__receive_while_waiting(self): +# # Arrange +# conn = Mock() +# awaiter = ReceptionStatusAwaiter() +# message_id = "1" +# s2_message = { +# "message_type": "Handshake", +# "message_id": message_id, +# "role": "RM", +# "supported_protocol_versions": ["1.0"], +# } +# s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": message_id, +# "status": "OK", +# } +# +# # Act +# wait_task = asyncio.create_task( +# awaiter.send_and_await_reception_status(conn, s2_message, True) +# ) +# should_be_waiting_still = not wait_task.done() +# await awaiter.receive_reception_status(s2_reception_status) +# await wait_task +# received_s2_reception_status = wait_task.result() +# +# # Assert +# expected_s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": "1", +# "status": "OK", +# } +# +# self.assertTrue(should_be_waiting_still) +# self.assertEqual(expected_s2_reception_status, received_s2_reception_status) +# +# async def test__send_and_await_reception_status__receive_while_waiting_not_okay(self): +# # Arrange +# conn = Mock() +# awaiter = ReceptionStatusAwaiter() +# message_id = "1" +# s2_message = { +# "message_type": "Handshake", +# "message_id": message_id, +# "role": "RM", +# "supported_protocol_versions": ["1.0"], +# } +# s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": message_id, +# "status": "INVALID_MESSAGE", +# } +# +# # Act / Assert +# wait_task = asyncio.create_task( +# awaiter.send_and_await_reception_status(conn, s2_message, True) +# ) +# await awaiter.receive_reception_status(s2_reception_status) +# +# with self.assertRaises(RuntimeError): +# await wait_task