From 299fd45c54f9650ee54ad9aa35a0db0503350e1c Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Wed, 4 Sep 2024 20:24:22 +0200 Subject: [PATCH 1/4] add_conn: Draft version of generic s2 connection with frbc and no control resource manager handlers. --- dev-requirements.txt | 167 +++++--------- setup.cfg | 1 + src/s2python/common/__init__.py | 1 + src/s2python/example_frbc_rm.py | 34 +++ src/s2python/frbc/rm.py | 0 src/s2python/reception_status_awaiter.py | 60 +++++ src/s2python/s2_connection.py | 240 ++++++++++++++++++++ src/s2python/s2_control_type.py | 53 +++++ src/s2python/validate_values_mixin.py | 135 +++++------ tests/unit/reception_status_awaiter_test.py | 224 ++++++++++++++++++ 10 files changed, 743 insertions(+), 172 deletions(-) create mode 100644 src/s2python/example_frbc_rm.py create mode 100644 src/s2python/frbc/rm.py create mode 100644 src/s2python/reception_status_awaiter.py create mode 100644 src/s2python/s2_connection.py create mode 100644 src/s2python/s2_control_type.py create mode 100644 tests/unit/reception_status_awaiter_test.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 2ef5346..1b55bb4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,29 +6,25 @@ # alabaster==0.7.13 # via sphinx -argcomplete==3.1.2 +argcomplete==3.5.0 # via datamodel-code-generator -astroid==3.0.1 +astroid==3.2.4 # via pylint -attrs==23.1.0 - # via jsonschema -babel==2.13.1 +babel==2.16.0 # via sphinx -black==23.10.1 +black==24.8.0 # via datamodel-code-generator -build==1.0.3 +build==1.2.1 # via pip-tools -cachetools==5.3.2 +cachetools==5.5.0 # via tox -certifi==2023.7.22 +certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 - # via - # prance - # tox -charset-normalizer==3.3.1 + # via tox +charset-normalizer==3.3.2 # via requests click==8.1.7 # via @@ -37,171 +33,133 @@ click==8.1.7 # s2-python (setup.cfg) colorama==0.4.6 # via tox -coverage[toml]==7.3.2 +coverage[toml]==7.6.1 # via pytest-cov -datamodel-code-generator==0.22.1 +datamodel-code-generator==0.25.9 # via s2-python (setup.cfg) -dill==0.3.7 +dill==0.3.8 # via pylint -distlib==0.3.7 +distlib==0.3.8 # via virtualenv -dnspython==2.4.2 +dnspython==2.6.1 # via email-validator -docutils==0.18.1 +docutils==0.20.1 # via # sphinx # sphinx-rtd-theme # sphinx-tabs -email-validator==2.1.0.post1 +email-validator==2.2.0 # via pydantic -exceptiongroup==1.1.3 +exceptiongroup==1.2.2 # via pytest -filelock==3.13.0 +filelock==3.15.4 # via # tox # virtualenv -genson==1.2.2 +genson==1.3.0 # via datamodel-code-generator -identify==2.5.31 +identify==2.6.0 # via pre-commit -idna==3.4 +idna==3.8 # via # email-validator # requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 +importlib-metadata==8.4.0 # via # build # sphinx -importlib-resources==5.13.0 - # via - # jsonschema - # openapi-spec-validator inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 # via pytest -isort==5.12.0 +isort==5.13.2 # via # datamodel-code-generator # pylint -jinja2==3.1.2 +jinja2==3.1.4 # via # datamodel-code-generator # sphinx -jsonschema==4.17.3 - # via - # jsonschema-spec - # openapi-schema-validator - # openapi-spec-validator -jsonschema-spec==0.1.6 - # via openapi-spec-validator -lazy-object-proxy==1.9.0 - # via openapi-spec-validator -markupsafe==2.1.3 +markupsafe==2.1.5 # via jinja2 mccabe==0.7.0 # via pylint -mypy==1.6.1 +mypy==1.11.2 # via s2-python (setup.cfg) mypy-extensions==1.0.0 # via # black # mypy -nodeenv==1.8.0 +nodeenv==1.9.1 # via pre-commit -openapi-schema-validator==0.4.4 - # via openapi-spec-validator -openapi-spec-validator==0.5.7 - # via datamodel-code-generator -packaging==23.2 +packaging==24.1 # via # black # build # datamodel-code-generator - # prance # pyproject-api # pytest # sphinx # tox -pathable==0.4.3 - # via jsonschema-spec -pathspec==0.11.2 +pathspec==0.12.1 # via black -pip-tools==7.3.0 +pip-tools==7.4.1 # via s2-python (setup.cfg) -pkgutil-resolve-name==1.3.10 - # via jsonschema -platformdirs==3.11.0 +platformdirs==4.2.2 # via # black # pylint # tox # virtualenv -pluggy==1.3.0 +pluggy==1.5.0 # via # pytest # tox -prance==23.6.21.0 - # via datamodel-code-generator pre-commit==3.5.0 # via s2-python (setup.cfg) -pydantic[email]==1.10.13 +pydantic[email]==1.10.18 # via # datamodel-code-generator # s2-python (setup.cfg) -pygments==2.16.1 +pygments==2.18.0 # via # sphinx # sphinx-tabs -pylint==3.0.2 +pylint==3.2.6 # via s2-python (setup.cfg) -pyproject-api==1.6.1 +pyproject-api==1.7.1 # via tox -pyproject-hooks==1.0.0 - # via build -pyrsistent==0.20.0 - # via jsonschema -pysnooper==1.2.0 - # via datamodel-code-generator -pytest==7.4.3 +pyproject-hooks==1.1.0 + # via + # build + # pip-tools +pytest==8.3.2 # via # pytest-cov # pytest-timer # s2-python (setup.cfg) -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via pytest-cover pytest-cover==3.0.0 # via pytest-coverage pytest-coverage==0.0 # via s2-python (setup.cfg) -pytest-timer==0.0.11 +pytest-timer==1.0.0 # via s2-python (setup.cfg) -pytz==2023.3.post1 +pytz==2024.1 # via # babel # s2-python (setup.cfg) -pyyaml==6.0.1 +pyyaml==6.0.2 # via - # jsonschema-spec + # datamodel-code-generator # pre-commit -requests==2.31.0 - # via - # jsonschema-spec - # prance - # sphinx -rfc3339-validator==0.1.4 - # via openapi-schema-validator -ruamel-yaml==0.18.3 - # via prance -ruamel-yaml-clib==0.2.8 - # via ruamel-yaml +requests==2.32.3 + # via sphinx six==1.16.0 - # via - # prance - # rfc3339-validator - # sphinxcontrib-httpdomain + # via sphinxcontrib-httpdomain snowballstemmer==2.2.0 # via sphinx sphinx==7.1.2 @@ -217,9 +175,9 @@ sphinx-copybutton==0.5.2 # via s2-python (setup.cfg) sphinx-fontawesome==0.0.6 # via s2-python (setup.cfg) -sphinx-rtd-theme==1.3.0 +sphinx-rtd-theme==2.0.0 # via s2-python (setup.cfg) -sphinx-tabs==3.4.4 +sphinx-tabs==3.4.5 # via s2-python (setup.cfg) sphinxcontrib-applehelp==1.0.4 # via sphinx @@ -248,34 +206,33 @@ tomli==2.0.1 # pip-tools # pylint # pyproject-api - # pyproject-hooks # pytest # tox -tomlkit==0.12.1 +tomlkit==0.13.2 # via pylint -tox==4.11.3 +tox==4.18.0 # via s2-python (setup.cfg) -types-pytz==2023.3.1.1 +types-pytz==2024.1.0.20240417 # via s2-python (setup.cfg) -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via # astroid # black # mypy # pydantic # pylint -urllib3==2.0.7 +urllib3==2.2.2 # via requests -virtualenv==20.24.6 +virtualenv==20.26.3 # via # pre-commit # tox -wheel==0.41.3 +websockets==13.0.1 + # via s2-python (setup.cfg) +wheel==0.44.0 # via pip-tools -zipp==3.17.0 - # via - # importlib-metadata - # importlib-resources +zipp==3.20.1 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/setup.cfg b/setup.cfg index d861495..0cbc265 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,7 @@ install_requires = pydantic~=1.10.7 pytz click + websockets~=13.0.1 [options.packages.find] where = src diff --git a/src/s2python/common/__init__.py b/src/s2python/common/__init__.py index 6bc46f5..4f099ed 100644 --- a/src/s2python/common/__init__.py +++ b/src/s2python/common/__init__.py @@ -1,5 +1,6 @@ from s2python.generated.gen_s2 import ( RoleType, + Currency, CommodityQuantity, Commodity, InstructionStatus, diff --git a/src/s2python/example_frbc_rm.py b/src/s2python/example_frbc_rm.py new file mode 100644 index 0000000..c8efd0a --- /dev/null +++ b/src/s2python/example_frbc_rm.py @@ -0,0 +1,34 @@ +import uuid + +from s2python.common import EnergyManagementRole, Duration, Role, RoleType, Commodity, Currency +from s2python.frbc import FRBCInstruction, FRBCSystemDescription +from s2python.s2_connection import S2Connection, AssetDetails +from s2python.s2_control_type import FRBCControlType + + +class MyFRBCControlType(FRBCControlType): + def handle_instruction(self, conn: S2Connection, msg: FRBCInstruction) -> None: + print(f"I have received the message {msg} from {conn}") + + def activate(self) -> None: + print("It is now activated.") + + def deactivate(self) -> None: + print("It is now deactivated.") + + +s2_conn = S2Connection( + url="http://cem_is_here.com:8080/", + role=EnergyManagementRole.RM, + control_types=[MyFRBCControlType()], + asset_details=AssetDetails( + resource_id=str(uuid.uuid4()), + name="Some asset", + instruction_processing_delay=Duration.from_milliseconds(20), + roles=[Role(role=RoleType.ENERGY_CONSUMER, commodity=Commodity.ELECTRICITY)], + currency=Currency.EUR, + ), +) + +s2_conn.start_as_rm() +s2_conn.send_msg_and_await_reception_status(FRBCSystemDescription(...)) diff --git a/src/s2python/frbc/rm.py b/src/s2python/frbc/rm.py new file mode 100644 index 0000000..e69de29 diff --git a/src/s2python/reception_status_awaiter.py b/src/s2python/reception_status_awaiter.py new file mode 100644 index 0000000..824aa2e --- /dev/null +++ b/src/s2python/reception_status_awaiter.py @@ -0,0 +1,60 @@ +"""ReceptationStatusAwaiter class which notifies any coroutine waiting for a certain reception status message. + +Copied from https://github.com/flexiblepower/s2-analyzer/blob/main/backend/s2_analyzer_backend/reception_status_awaiter.py under Apache2 license on 31-08-2024. +""" + +import asyncio +import uuid +from typing import Dict + +from s2python.common import ReceptionStatus +from s2python.validate_values_mixin import S2Message + + +class ReceptionStatusAwaiter: + received: Dict[uuid.UUID, S2Message] + awaiting: Dict[uuid.UUID, asyncio.Event] + + def __init__(self): + self.received = {} + self.awaiting = {} + + async def wait_for_reception_status( + self, message_id: uuid.UUID, timeout_reception_status: float + ) -> ReceptionStatus: + # TODO Add timeout + if message_id in self.received: + reception_status = self.received[message_id] + else: + if message_id in self.awaiting: + received_event = self.awaiting[message_id] + else: + received_event = asyncio.Event() + self.awaiting[message_id] = received_event + + async with asyncio.timeout(timeout_reception_status): + await received_event.wait() + reception_status = self.received.get(message_id) + + if message_id in self.awaiting: + del self.awaiting[message_id] + + return reception_status + + async def receive_reception_status(self, reception_status: ReceptionStatus) -> None: + if reception_status.get("message_type") != "ReceptionStatus": + raise RuntimeError( + f"Expected a ReceptionStatus but received message {reception_status}" + ) + message_id = reception_status["subject_message_id"] + + if reception_status.subject_message_id in self.received: + raise RuntimeError( + f"ReceptationStatus for message_subject_id {message_id} has already been received!" + ) + self.received[message_id] = reception_status + awaiting = self.awaiting.get(message_id) + + if awaiting: + awaiting.set() + del self.awaiting[message_id] diff --git a/src/s2python/s2_connection.py b/src/s2python/s2_connection.py new file mode 100644 index 0000000..c168c8d --- /dev/null +++ b/src/s2python/s2_connection.py @@ -0,0 +1,240 @@ +import asyncio +import inspect +import logging +import threading +import uuid +from dataclasses import dataclass +from typing import Optional, List, Type, Dict, Callable, Awaitable, Union + +from websockets.asyncio.client import ClientConnection as WSConnection, connect as ws_connect + +from s2python.common import ( + ReceptionStatusValues, + ReceptionStatus, + Handshake, + EnergyManagementRole, + Role, + HandshakeResponse, + ResourceManagerDetails, + Duration, + Currency, + SelectControlType, +) +from s2python.reception_status_awaiter import ReceptionStatusAwaiter +from s2python.s2_control_type import S2ControlType +from s2python.s2_parser import S2Parser +from s2python.validate_values_mixin import S2Message + + +logger = logging.getLogger("s2python") + + +@dataclass +class AssetDetails: + resource_id: str + + instruction_processing_delay: Duration + roles: List[Role] = None + currency: Optional[Currency] = None + + name: Optional[str] = None + manufacturer: Optional[str] = None + model: Optional[str] = None + firmware_version: Optional[str] = None + serial_number: Optional[str] = None + + def to_resource_manager_details( + self, control_types: List[S2ControlType] + ) -> ResourceManagerDetails: + return ResourceManagerDetails( + available_control_types=[ + control_type.get_protocol_control_type() for control_type in control_types + ], + currency=self.currency, + firmware_version=self.firmware_version, + instruction_processing_delay=self.instruction_processing_delay, + manufacturer=self.manufacturer, + message_id=uuid.uuid4(), + model=self.model, + name=self.name, + provides_forecast=True, # TODO + provides_power_measurement_types=[], # TODO + resource_id=self.resource_id, + roles=self.roles, + serial_number=self.serial_number, + ) + + +S2MessageHandler = Union[ + Callable[["S2Connection", S2Message], None], + Callable[["S2Connection", S2Message], Awaitable[None]], +] + + +class MessageHandlers: + handlers: Dict[Type[S2Message], S2MessageHandler] + + def __init__(self): + self.handlers = {} + + async def handle_message(self, connection: "S2Connection", msg: S2Message) -> None: + """Handle the S2 message using the registered handler. + + :param connection: The S2 conncetion the `msg` is received from. + :param msg: The S2 message + """ + handler = self.handlers.get(type(msg)) + + if handler: + if inspect.iscoroutinefunction(handler): + await handler(connection, msg) + else: + handler(connection, msg) + else: + logger.warning( + "Received a message of type %s but no handler is registered. Ignoring the message.", + type(msg), + ) + + def register_handler(self, msg_type: Type[S2Message], handler: S2MessageHandler) -> None: + """Register a coroutine function or a normal function as the handler for a specific S2 message type. + + :param msg_type: The S2 message type to attach the handler to. + :param handler: The function (asynchronuous or normal) which should handle the S2 message. + """ + self.handlers[msg_type] = handler + + +class S2Connection: + url: str + reception_status_awaiter: ReceptionStatusAwaiter + ws: Optional[WSConnection] + s2_parser: S2Parser + control_types: List[S2ControlType] + role: EnergyManagementRole + asset_details: AssetDetails + + _thread: threading.Thread + _received_messages: asyncio.Queue + _handlers: MessageHandlers + _receiver_task: asyncio.Task + _current_control_type: Optional[S2ControlType] + + def __init__( + self, + url: str, + role: EnergyManagementRole, + control_types: List[S2ControlType], + asset_details: AssetDetails, + ): + self.url = url + self.reception_status_awaiter = ReceptionStatusAwaiter() + self.s2_parser = S2Parser() + + self._received_messages = asyncio.Queue() + self._handlers = MessageHandlers() + self._current_control_type = None + + self.control_types = control_types + self.role = role + self.asset_details = asset_details + + self._handlers.register_handler(SelectControlType, self.handle_select_control_type) + + def start_as_rm(self) -> None: + self._thread = threading.Thread(target=self._run_as_rm()) + self._thread.start() + + def _run_as_rm(self): + eventloop = asyncio.new_event_loop() + eventloop.run_until_complete(self.connect_as_rm()) + eventloop.run_until_complete(self._handle_received_messages()) + + async def connect_as_rm(self) -> None: + self.ws = await ws_connect(uri=self.url) + self._receiver_task = asyncio.create_task(self._received_messages) + + await self.send_msg_and_await_reception_status( + Handshake(message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[]) + ) + + logger.debug("Send handshake to CEM. Waiting for Handshake and HandshakeResponse from CEM.") + cem_handshake_responses = [self._receive_next_message(), self._receive_next_message()] + handshake_response = next( + filter(lambda m: isinstance(m, HandshakeResponse), cem_handshake_responses), None + ) + cem_handshake = next( + filter(lambda m: isinstance(m, Handshake), cem_handshake_responses), None + ) + + logger.debug( + "CEM supports S2 protocol versions: %s. CEM selected to use version %s", + cem_handshake.supported_protocol_versions, + handshake_response.selected_protocol_version, + ) + logger.debug("Handshake complete. Sending first ResourceManagerDetails.") + + await self.send_msg_and_await_reception_status( + self.asset_details.to_resource_manager_details(self.control_types) + ) + + async def handle_select_control_type(self, _: "S2Connection", message: S2Message) -> None: + logger.debug("CEM selected control type %s. Activating control type.", message.control_type) + + selected_control_type: S2ControlType = next( + filter( + lambda c: c.get_protocol_control_type() == message.control_type, self.control_types + ), + None, + ) + + if self._current_control_type is not None: + self._current_control_type.deactivate() + + self._current_control_type = selected_control_type + + if self._current_control_type is not None: + self._current_control_type.activate() + self._current_control_type.register_handlers(self._handlers) + + async def _receive_next_message(self) -> S2Message: + """Receive next non-ReceptionStatus message. + + :return: The next S2 message which is not a ReceptionStatus. + """ + return await self._received_messages.get() + + async def _receive_messages(self) -> None: + """Receives all incoming messages in the form of a generator. + + Will also receive the ReceptionStatus messages but instead of yielding these messages, they are routed + to any calls of `send_msg_and_await_reception_status`. + """ + async for message in self.ws: + s2_msg: S2Message = self.s2_parser.parse_as_any_message(message) + + if isinstance(s2_msg, ReceptionStatus): + await self.reception_status_awaiter.receive_reception_status(s2_msg) + else: + await self._received_messages.put(s2_msg) + + async def _send_and_forget(self, s2_msg: S2Message) -> None: + await self.ws.send(s2_msg.to_json()) + + async def send_msg_and_await_reception_status( + self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True + ) -> S2Message: + await self._send_and_forget(s2_msg) + reception_status = await self.reception_status_awaiter.wait_for_reception_status( + s2_msg.message_id, timeout_reception_status + ) + + if reception_status.status != ReceptionStatusValues.OK and raise_on_error: + raise RuntimeError(f"ReceptionStatus was not OK but rather {reception_status.status}") + + return reception_status + + async def _handle_received_messages(self) -> None: + while True: + msg = await self._received_messages.get() + await self._handlers.handle_message(self, msg) diff --git a/src/s2python/s2_control_type.py b/src/s2python/s2_control_type.py new file mode 100644 index 0000000..f792ada --- /dev/null +++ b/src/s2python/s2_control_type.py @@ -0,0 +1,53 @@ +import abc +import typing + +from s2python.common import ControlType as ProtocolControlType +from s2python.frbc import FRBCInstruction + +if typing.TYPE_CHECKING: + from s2python.s2_connection import S2Connection, MessageHandlers + + +class S2ControlType(abc.ABC): + @abc.abstractmethod + def get_protocol_control_type(self) -> ProtocolControlType: ... + + @abc.abstractmethod + def register_handlers(self, handlers: MessageHandlers) -> None: ... + + @abc.abstractmethod + def activate(self) -> None: ... + + @abc.abstractmethod + def deactivate(self) -> None: ... + + +class FRBCControlType(S2ControlType): + def get_protocol_control_type(self) -> ProtocolControlType: + return ProtocolControlType.FILL_RATE_BASED_CONTROL + + def register_handlers(self, handlers: MessageHandlers) -> None: + handlers.register_handler(FRBCInstruction, self.handle_instruction) + + @abc.abstractmethod + def handle_instruction(self, conn: S2Connection, msg: FRBCInstruction) -> None: ... + + @abc.abstractmethod + def activate(self) -> None: ... + + @abc.abstractmethod + def deactivate(self) -> None: ... + + +class NoControlControlType(S2ControlType): + def get_protocol_control_type(self) -> ProtocolControlType: + return ProtocolControlType.NOT_CONTROLABLE + + def register_handlers(self, handlers: MessageHandlers) -> None: + pass + + @abc.abstractmethod + def activate(self) -> None: ... + + @abc.abstractmethod + def deactivate(self) -> None: ... diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index 14e305a..46d83a5 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -30,70 +30,73 @@ MappingIntStrAny = Mapping[IntStr, Any] -class SupportsValidation(Protocol[B_co]): - # ValidateValuesMixin methods - def to_json(self) -> str: - ... - - def to_dict(self) -> Dict: - ... - - @classmethod - def from_json(cls, json_str: str) -> B_co: - ... - - @classmethod - def from_dict(cls, json_dict: Dict) -> B_co: - ... - - # Pydantic methods - def json( # pylint: disable=too-many-arguments - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - encoder: Optional[Callable[[Any], Any]] = None, - models_as_dict: bool = True, - **dumps_kwargs: Any, - ) -> str: - ... - - def dict( # pylint: disable=too-many-arguments - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> Dict[str, Any]: - ... - - @classmethod - def parse_raw( # pylint: disable=too-many-arguments - cls, - b: StrBytes, - *, - content_type: str = ..., - encoding: str = ..., - proto: PydanticProtocol = ..., - allow_pickle: bool = ..., - ) -> B_co: - ... - - @classmethod - def parse_obj(cls, obj: Any) -> "B_co": - ... - - -C = TypeVar("C", bound="SupportsValidation") +# class SupportsValidation(Protocol[B_co]): +# def lets_disable_this(self) -> None: +# pass +# +# # ValidateValuesMixin methods +# def to_json(self) -> str: +# ... +# +# def to_dict(self) -> Dict: +# ... +# +# @classmethod +# def from_json(cls, json_str: str) -> B_co: +# ... +# +# @classmethod +# def from_dict(cls, json_dict: Dict) -> B_co: +# ... +# +# # Pydantic methods +# def json( # pylint: disable=too-many-arguments +# self, +# *, +# include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, +# exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, +# by_alias: bool = False, +# skip_defaults: Optional[bool] = None, +# exclude_unset: bool = False, +# exclude_defaults: bool = False, +# exclude_none: bool = False, +# encoder: Optional[Callable[[Any], Any]] = None, +# models_as_dict: bool = True, +# **dumps_kwargs: Any, +# ) -> str: +# ... +# +# def dict( # pylint: disable=too-many-arguments +# self, +# *, +# include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, +# exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, +# by_alias: bool = False, +# skip_defaults: Optional[bool] = None, +# exclude_unset: bool = False, +# exclude_defaults: bool = False, +# exclude_none: bool = False, +# ) -> Dict[str, Any]: +# ... +# +# @classmethod +# def parse_raw( # pylint: disable=too-many-arguments +# cls, +# b: StrBytes, +# *, +# content_type: str = ..., +# encoding: str = ..., +# proto: PydanticProtocol = ..., +# allow_pickle: bool = ..., +# ) -> B_co: +# ... +# +# @classmethod +# def parse_obj(cls, obj: Any) -> "B_co": +# ... +# +# +C = TypeVar("C") # , bound="SupportsValidation") class ValidateValuesMixin(Generic[C]): @@ -101,9 +104,7 @@ def to_json(self: C) -> str: try: return self.json(by_alias=True, exclude_none=True) except (ValidationError, TypeError) as e: - raise S2ValidationError( - self, "Pydantic raised a format validation error." - ) from e + raise S2ValidationError(self, "Pydantic raised a format validation error.") from e def to_dict(self: C) -> dict: return self.dict() diff --git a/tests/unit/reception_status_awaiter_test.py b/tests/unit/reception_status_awaiter_test.py new file mode 100644 index 0000000..6bfd92e --- /dev/null +++ b/tests/unit/reception_status_awaiter_test.py @@ -0,0 +1,224 @@ +"""Tests for ReceptionStatusAwaiter. + +Copied from https://github.com/flexiblepower/s2-analyzer/blob/main/backend/test/s2_analyzer_backend/reception_status_awaiter_test.py under Apache2 license on 31-08-2024. +""" + +import asyncio +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, Mock + +from s2python.reception_status_awaiter import ReceptionStatusAwaiter + + +class ReceptionStatusAwaiterTest(IsolatedAsyncioTestCase): + async def test__wait_for_reception_status__receive_while_waiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = "1" + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": message_id, + "status": "OK", + } + + # Act + wait_task = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) + should_be_waiting_still = not wait_task.done() + await awaiter.receive_reception_status(s2_reception_status) + await wait_task + received_s2_reception_status = wait_task.result() + + # Assert + expected_s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + self.assertTrue(should_be_waiting_still) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status) + + async def test__wait_for_reception_status__already_received(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = "1" + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": message_id, + "status": "OK", + } + + # Act + await awaiter.receive_reception_status(s2_reception_status) + received_s2_reception_status = await awaiter.wait_for_reception_status(message_id) + + # Assert + expected_s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + self.assertEqual(expected_s2_reception_status, received_s2_reception_status) + + async def test__wait_for_reception_status__multiple_receive_while_waiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + message_id = "1" + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": message_id, + "status": "OK", + } + + # Act + wait_task_1 = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) + wait_task_2 = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) + should_be_waiting_still_1 = not wait_task_1.done() + should_be_waiting_still_2 = not wait_task_2.done() + await awaiter.receive_reception_status(s2_reception_status) + await wait_task_1 + await wait_task_2 + received_s2_reception_status_1 = wait_task_1.result() + received_s2_reception_status_2 = wait_task_2.result() + + # Assert + expected_s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + self.assertTrue(should_be_waiting_still_1) + self.assertTrue(should_be_waiting_still_2) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status_1) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status_2) + + async def test__receive_reception_status__wrong_message(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + s2_msg = {"message_type": "NotAReceptionStatus", "subject_message_id": "1", "status": "OK"} + + # Act / Assert + with self.assertRaises(RuntimeError): + await awaiter.receive_reception_status(s2_msg) + + async def test__receive_reception_status__received_duplicate(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + # Act / Assert + await awaiter.receive_reception_status(s2_reception_status) + with self.assertRaises(RuntimeError): + await awaiter.receive_reception_status(s2_reception_status) + + async def test__receive_reception_status__receive_no_awaiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + # Act + await awaiter.receive_reception_status(s2_reception_status) + + # Assert + expected_received = { + "1": {"message_type": "ReceptionStatus", "subject_message_id": "1", "status": "OK"} + } + self.assertEqual(awaiter.received, expected_received) + self.assertEqual(awaiter.awaiting, {}) + + async def test__receive_reception_status__receive_with_awaiting(self): + # Arrange + awaiter = ReceptionStatusAwaiter() + awaiting_event = asyncio.Event() + awaiter.awaiting = {"1": awaiting_event} + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + # Act + should_not_be_set = not awaiting_event.is_set() + await awaiter.receive_reception_status(s2_reception_status) + should_be_set = awaiting_event.is_set() + + # Assert + expected_received = { + "1": {"message_type": "ReceptionStatus", "subject_message_id": "1", "status": "OK"} + } + + self.assertTrue(should_not_be_set) + self.assertTrue(should_be_set) + self.assertEqual(awaiter.received, expected_received) + self.assertEqual(awaiter.awaiting, {}) + + async def test__send_and_await_reception_status__receive_while_waiting(self): + # Arrange + conn = Mock() + awaiter = ReceptionStatusAwaiter() + message_id = "1" + s2_message = { + "message_type": "Handshake", + "message_id": message_id, + "role": "RM", + "supported_protocol_versions": ["1.0"], + } + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": message_id, + "status": "OK", + } + + # Act + wait_task = asyncio.create_task( + awaiter.send_and_await_reception_status(conn, s2_message, True) + ) + should_be_waiting_still = not wait_task.done() + await awaiter.receive_reception_status(s2_reception_status) + await wait_task + received_s2_reception_status = wait_task.result() + + # Assert + expected_s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": "1", + "status": "OK", + } + + self.assertTrue(should_be_waiting_still) + self.assertEqual(expected_s2_reception_status, received_s2_reception_status) + + async def test__send_and_await_reception_status__receive_while_waiting_not_okay(self): + # Arrange + conn = Mock() + awaiter = ReceptionStatusAwaiter() + message_id = "1" + s2_message = { + "message_type": "Handshake", + "message_id": message_id, + "role": "RM", + "supported_protocol_versions": ["1.0"], + } + s2_reception_status = { + "message_type": "ReceptionStatus", + "subject_message_id": message_id, + "status": "INVALID_MESSAGE", + } + + # Act / Assert + wait_task = asyncio.create_task( + awaiter.send_and_await_reception_status(conn, s2_message, True) + ) + await awaiter.receive_reception_status(s2_reception_status) + + with self.assertRaises(RuntimeError): + await wait_task From b7af280d773a8560b11acb1dbd30662d0a82a9be Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Fri, 6 Sep 2024 11:52:10 +0200 Subject: [PATCH 2/4] add_conn: revert temporary disable of SupportsValidation protocol. --- src/s2python/validate_values_mixin.py | 135 +++++++++++++------------- 1 file changed, 67 insertions(+), 68 deletions(-) diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index 46d83a5..14e305a 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -30,73 +30,70 @@ MappingIntStrAny = Mapping[IntStr, Any] -# class SupportsValidation(Protocol[B_co]): -# def lets_disable_this(self) -> None: -# pass -# -# # ValidateValuesMixin methods -# def to_json(self) -> str: -# ... -# -# def to_dict(self) -> Dict: -# ... -# -# @classmethod -# def from_json(cls, json_str: str) -> B_co: -# ... -# -# @classmethod -# def from_dict(cls, json_dict: Dict) -> B_co: -# ... -# -# # Pydantic methods -# def json( # pylint: disable=too-many-arguments -# self, -# *, -# include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, -# exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, -# by_alias: bool = False, -# skip_defaults: Optional[bool] = None, -# exclude_unset: bool = False, -# exclude_defaults: bool = False, -# exclude_none: bool = False, -# encoder: Optional[Callable[[Any], Any]] = None, -# models_as_dict: bool = True, -# **dumps_kwargs: Any, -# ) -> str: -# ... -# -# def dict( # pylint: disable=too-many-arguments -# self, -# *, -# include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, -# exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, -# by_alias: bool = False, -# skip_defaults: Optional[bool] = None, -# exclude_unset: bool = False, -# exclude_defaults: bool = False, -# exclude_none: bool = False, -# ) -> Dict[str, Any]: -# ... -# -# @classmethod -# def parse_raw( # pylint: disable=too-many-arguments -# cls, -# b: StrBytes, -# *, -# content_type: str = ..., -# encoding: str = ..., -# proto: PydanticProtocol = ..., -# allow_pickle: bool = ..., -# ) -> B_co: -# ... -# -# @classmethod -# def parse_obj(cls, obj: Any) -> "B_co": -# ... -# -# -C = TypeVar("C") # , bound="SupportsValidation") +class SupportsValidation(Protocol[B_co]): + # ValidateValuesMixin methods + def to_json(self) -> str: + ... + + def to_dict(self) -> Dict: + ... + + @classmethod + def from_json(cls, json_str: str) -> B_co: + ... + + @classmethod + def from_dict(cls, json_dict: Dict) -> B_co: + ... + + # Pydantic methods + def json( # pylint: disable=too-many-arguments + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, + **dumps_kwargs: Any, + ) -> str: + ... + + def dict( # pylint: disable=too-many-arguments + self, + *, + include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Dict[str, Any]: + ... + + @classmethod + def parse_raw( # pylint: disable=too-many-arguments + cls, + b: StrBytes, + *, + content_type: str = ..., + encoding: str = ..., + proto: PydanticProtocol = ..., + allow_pickle: bool = ..., + ) -> B_co: + ... + + @classmethod + def parse_obj(cls, obj: Any) -> "B_co": + ... + + +C = TypeVar("C", bound="SupportsValidation") class ValidateValuesMixin(Generic[C]): @@ -104,7 +101,9 @@ def to_json(self: C) -> str: try: return self.json(by_alias=True, exclude_none=True) except (ValidationError, TypeError) as e: - raise S2ValidationError(self, "Pydantic raised a format validation error.") from e + raise S2ValidationError( + self, "Pydantic raised a format validation error." + ) from e def to_dict(self: C) -> dict: return self.dict() From 80219698e1e68d9d7c2d0682b1137b3cc4ee42e9 Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Mon, 9 Sep 2024 18:20:00 +0200 Subject: [PATCH 3/4] Rough first version working. --- src/s2python/example_frbc_rm.py | 154 +++++++++++- src/s2python/reception_status_awaiter.py | 16 +- src/s2python/s2_connection.py | 292 +++++++++++++++++++---- src/s2python/s2_control_type.py | 23 +- src/s2python/validate_values_mixin.py | 10 +- src/s2python/version.py | 2 + 6 files changed, 411 insertions(+), 86 deletions(-) diff --git a/src/s2python/example_frbc_rm.py b/src/s2python/example_frbc_rm.py index c8efd0a..3da66c2 100644 --- a/src/s2python/example_frbc_rm.py +++ b/src/s2python/example_frbc_rm.py @@ -1,26 +1,150 @@ +import logging +import sys import uuid +import signal +import datetime +from typing import Callable -from s2python.common import EnergyManagementRole, Duration, Role, RoleType, Commodity, Currency -from s2python.frbc import FRBCInstruction, FRBCSystemDescription +from s2python.common import ( + EnergyManagementRole, + Duration, + Role, + RoleType, + Commodity, + Currency, + NumberRange, + PowerRange, + CommodityQuantity, +) +from s2python.frbc import ( + FRBCInstruction, + FRBCSystemDescription, + FRBCActuatorDescription, + FRBCStorageDescription, + FRBCOperationMode, + FRBCOperationModeElement, + FRBCFillLevelTargetProfile, + FRBCFillLevelTargetProfileElement, + FRBCStorageStatus, + FRBCActuatorStatus, +) from s2python.s2_connection import S2Connection, AssetDetails -from s2python.s2_control_type import FRBCControlType +from s2python.s2_control_type import FRBCControlType, NoControlControlType + + +logger = logging.getLogger("s2python") +logger.addHandler(logging.StreamHandler(sys.stdout)) +logger.setLevel(logging.DEBUG) class MyFRBCControlType(FRBCControlType): - def handle_instruction(self, conn: S2Connection, msg: FRBCInstruction) -> None: + def handle_instruction( + self, conn: S2Connection, msg: FRBCInstruction, send_okay: Callable[[], None] + ) -> None: print(f"I have received the message {msg} from {conn}") - def activate(self) -> None: - print("It is now activated.") + def activate(self, conn: S2Connection) -> None: + print("The control type FRBC is now activated.") + + print("Time to send a FRBC SystemDescription") + actuator_id = uuid.uuid4() + operation_mode_id = uuid.uuid4() + conn.send_msg_and_await_reception_status_sync( + FRBCSystemDescription( + message_id=uuid.uuid4(), + valid_from=datetime.datetime.now(tz=datetime.timezone.utc), + actuators=[ + FRBCActuatorDescription( + id=actuator_id, + operation_modes=[ + FRBCOperationMode( + id=operation_mode_id, + elements=[ + FRBCOperationModeElement( + fill_level_range=NumberRange( + start_of_range=0.0, end_of_range=100.0 + ), + fill_rate=NumberRange( + start_of_range=-5.0, end_of_range=5.0 + ), + power_ranges=[ + PowerRange( + start_of_range=-200.0, + end_of_range=200.0, + commodity_quantity=CommodityQuantity.ELECTRIC_POWER_L1, + ) + ], + ) + ], + diagnostic_label="Load & unload battery", + abnormal_condition_only=False, + ) + ], + transitions=[], + timers=[], + supported_commodities=[Commodity.ELECTRICITY], + ) + ], + storage=FRBCStorageDescription( + fill_level_range=NumberRange(start_of_range=0.0, end_of_range=100.0), + fill_level_label="%", + diagnostic_label="Imaginary battery", + provides_fill_level_target_profile=True, + provides_leakage_behaviour=False, + provides_usage_forecast=False, + ), + ) + ) + print("Also send the target profile") + + conn.send_msg_and_await_reception_status_sync( + FRBCFillLevelTargetProfile( + message_id=uuid.uuid4(), + start_time=datetime.datetime.now(tz=datetime.timezone.utc), + elements=[ + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(30_000), + fill_level_range=NumberRange(start_of_range=20.0, end_of_range=30.0), + ), + FRBCFillLevelTargetProfileElement( + duration=Duration.from_milliseconds(300_000), + fill_level_range=NumberRange(start_of_range=40.0, end_of_range=50.0), + ), + ], + ) + ) + + print("Also send the storage status.") + conn.send_msg_and_await_reception_status_sync( + FRBCStorageStatus(message_id=uuid.uuid4(), present_fill_level=10.0) + ) + + print("Also send the actuator status.") + conn.send_msg_and_await_reception_status_sync( + FRBCActuatorStatus( + message_id=uuid.uuid4(), + actuator_id=actuator_id, + active_operation_mode_id=operation_mode_id, + operation_mode_factor=0.5, + ) + ) - def deactivate(self) -> None: - print("It is now deactivated.") + def deactivate(self, conn: S2Connection) -> None: + print("The control type FRBC is now deactivated.") + + +class MyNoControlControlType(NoControlControlType): + def activate(self, conn: S2Connection) -> None: + print("The control type NoControl is now activated.") + + def deactivate(self, conn: S2Connection) -> None: + print("The control type NoControl is now deactivated.") s2_conn = S2Connection( - url="http://cem_is_here.com:8080/", + url="ws://localhost:8001/backend/rm/s2python-frbc/cem/dummy_model/ws", role=EnergyManagementRole.RM, - control_types=[MyFRBCControlType()], + control_types=[MyFRBCControlType(), MyNoControlControlType()], asset_details=AssetDetails( resource_id=str(uuid.uuid4()), name="Some asset", @@ -30,5 +154,13 @@ def deactivate(self) -> None: ), ) + +def stop(signal_num, _current_stack_frame): + print(f"Received signal {signal_num}. Will stop S2 connection.") + s2_conn.stop() + + +signal.signal(signal.SIGINT, stop) +signal.signal(signal.SIGTERM, stop) + s2_conn.start_as_rm() -s2_conn.send_msg_and_await_reception_status(FRBCSystemDescription(...)) diff --git a/src/s2python/reception_status_awaiter.py b/src/s2python/reception_status_awaiter.py index 824aa2e..090cfd5 100644 --- a/src/s2python/reception_status_awaiter.py +++ b/src/s2python/reception_status_awaiter.py @@ -22,7 +22,6 @@ def __init__(self): async def wait_for_reception_status( self, message_id: uuid.UUID, timeout_reception_status: float ) -> ReceptionStatus: - # TODO Add timeout if message_id in self.received: reception_status = self.received[message_id] else: @@ -32,8 +31,7 @@ async def wait_for_reception_status( received_event = asyncio.Event() self.awaiting[message_id] = received_event - async with asyncio.timeout(timeout_reception_status): - await received_event.wait() + await asyncio.wait_for(received_event.wait(), timeout_reception_status) reception_status = self.received.get(message_id) if message_id in self.awaiting: @@ -42,19 +40,19 @@ async def wait_for_reception_status( return reception_status async def receive_reception_status(self, reception_status: ReceptionStatus) -> None: - if reception_status.get("message_type") != "ReceptionStatus": + if not isinstance(reception_status, ReceptionStatus): raise RuntimeError( f"Expected a ReceptionStatus but received message {reception_status}" ) - message_id = reception_status["subject_message_id"] if reception_status.subject_message_id in self.received: raise RuntimeError( - f"ReceptationStatus for message_subject_id {message_id} has already been received!" + f"ReceptationStatus for message_subject_id {reception_status.subject_message_id} has already been received!" ) - self.received[message_id] = reception_status - awaiting = self.awaiting.get(message_id) + + self.received[reception_status.subject_message_id] = reception_status + awaiting = self.awaiting.get(reception_status.subject_message_id) if awaiting: awaiting.set() - del self.awaiting[message_id] + del self.awaiting[reception_status.subject_message_id] diff --git a/src/s2python/s2_connection.py b/src/s2python/s2_connection.py index c168c8d..1e6ef1b 100644 --- a/src/s2python/s2_connection.py +++ b/src/s2python/s2_connection.py @@ -1,5 +1,6 @@ import asyncio import inspect +import json import logging import threading import uuid @@ -20,11 +21,13 @@ Currency, SelectControlType, ) +from s2python.generated.gen_s2 import CommodityQuantity from s2python.reception_status_awaiter import ReceptionStatusAwaiter from s2python.s2_control_type import S2ControlType from s2python.s2_parser import S2Parser +from s2python.s2_validation_error import S2ValidationError from s2python.validate_values_mixin import S2Message - +from s2python.version import S2_VERSION logger = logging.getLogger("s2python") @@ -58,7 +61,7 @@ def to_resource_manager_details( model=self.model, name=self.name, provides_forecast=True, # TODO - provides_power_measurement_types=[], # TODO + provides_power_measurement_types=[CommodityQuantity.ELECTRIC_POWER_L1], # TODO resource_id=self.resource_id, roles=self.roles, serial_number=self.serial_number, @@ -66,8 +69,8 @@ def to_resource_manager_details( S2MessageHandler = Union[ - Callable[["S2Connection", S2Message], None], - Callable[["S2Connection", S2Message], Awaitable[None]], + Callable[["S2Connection", S2Message, Callable[[], None]], None], + Callable[["S2Connection", S2Message, Awaitable[None]], Awaitable[None]], ] @@ -84,12 +87,62 @@ async def handle_message(self, connection: "S2Connection", msg: S2Message) -> No :param msg: The S2 message """ handler = self.handlers.get(type(msg)) - if handler: - if inspect.iscoroutinefunction(handler): - await handler(connection, msg) - else: - handler(connection, msg) + status_is_send = threading.Event() + + try: + if inspect.iscoroutinefunction(handler): + + async def send_okay(): + status_is_send.set() + + await connection.respond_with_reception_status( + subject_message_id=str(msg.message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + await handler(connection, msg, send_okay()) + + if not status_is_send.is_set(): + logger.warning( + "Handler for message %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type(msg), + ) + await send_okay() + else: + eventloop = asyncio.get_event_loop() + + def do_message(): + def send_okay(): + status_is_send.set() + + connection.respond_with_reception_status_sync( + subject_message_id=str(msg.message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + handler(connection, msg, send_okay) + + if not status_is_send.is_set(): + logger.warning( + "Handler for message %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type(msg), + ) + send_okay() + + await eventloop.run_in_executor(executor=None, func=do_message) + except Exception: + if not status_is_send.is_set(): + await connection.respond_with_reception_status( + subject_message_id=str(msg.message_id), + status=ReceptionStatusValues.PERMANENT_ERROR, + diagnostic_label=f"While processing message {msg.message_id} an unrecoverable error occurred.", + ) + raise else: logger.warning( "Received a message of type %s but no handler is registered. Ignoring the message.", @@ -115,10 +168,15 @@ class S2Connection: asset_details: AssetDetails _thread: threading.Thread - _received_messages: asyncio.Queue + _handlers: MessageHandlers - _receiver_task: asyncio.Task _current_control_type: Optional[S2ControlType] + _received_messages: asyncio.Queue + + _eventloop: asyncio.AbstractEventLoop + _receiver_task: Optional[asyncio.Task] + _handle_messages_task: Optional[asyncio.Task] + _stop_event: asyncio.Event def __init__( self, @@ -131,54 +189,125 @@ def __init__( self.reception_status_awaiter = ReceptionStatusAwaiter() self.s2_parser = S2Parser() - self._received_messages = asyncio.Queue() self._handlers = MessageHandlers() self._current_control_type = None + self._eventloop = asyncio.new_event_loop() + self._receiver_task = None + self._handle_messages_task = None + self.control_types = control_types self.role = role self.asset_details = asset_details - self._handlers.register_handler(SelectControlType, self.handle_select_control_type) + self._handlers.register_handler(SelectControlType, self.handle_select_control_type_as_rm) + self._handlers.register_handler(Handshake, self.handle_handshake) + self._handlers.register_handler(HandshakeResponse, self.handle_handshake_response_as_rm) def start_as_rm(self) -> None: - self._thread = threading.Thread(target=self._run_as_rm()) + self._thread = threading.Thread(target=self._run_eventloop) self._thread.start() + logger.debug("Started eventloop thread!") - def _run_as_rm(self): - eventloop = asyncio.new_event_loop() - eventloop.run_until_complete(self.connect_as_rm()) - eventloop.run_until_complete(self._handle_received_messages()) + def _run_eventloop(self) -> None: + logger.debug("Starting eventloop") + self._eventloop.run_until_complete(self._run_as_rm()) - async def connect_as_rm(self) -> None: - self.ws = await ws_connect(uri=self.url) - self._receiver_task = asyncio.create_task(self._received_messages) + def stop(self) -> None: + asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop) # TODO .result() - await self.send_msg_and_await_reception_status( - Handshake(message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[]) - ) + async def _do_stop(self): + logger.info("Will stop the S2 connection.") + if self._background_tasks: + self._background_tasks.cancel() + self._background_tasks = None - logger.debug("Send handshake to CEM. Waiting for Handshake and HandshakeResponse from CEM.") - cem_handshake_responses = [self._receive_next_message(), self._receive_next_message()] - handshake_response = next( - filter(lambda m: isinstance(m, HandshakeResponse), cem_handshake_responses), None + if self.ws: + await self.ws.close() + await self.ws.wait_closed() + + async def _run_as_rm(self): + logger.debug("Connecting as S2 resource manager.") + self._received_messages = asyncio.Queue() + await self.connect_ws() + + self._background_tasks = self._eventloop.create_task( + asyncio.wait( + (self._receive_messages(), self._handle_received_messages()), + return_when=asyncio.FIRST_EXCEPTION, + ) ) - cem_handshake = next( - filter(lambda m: isinstance(m, Handshake), cem_handshake_responses), None + + await self.connect_as_rm() + done: List[asyncio.Task] + pending: List[asyncio.Task] + (done, pending) = await self._background_tasks + + for task in done: + task.result() + + for task in pending: + task.cancel() + + async def connect_ws(self) -> None: + self.ws = await ws_connect(uri=self.url) + + async def connect_as_rm(self) -> None: + await self.send_msg_and_await_reception_status_async( + Handshake( + message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[S2_VERSION] + ) ) + logger.debug("Send handshake to CEM. Expecting Handshake and HandshakeResponse from CEM.") + + async def handle_handshake( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, Handshake): + logger.error( + "Handler for Handshake received a message of the wrong type: %s", type(message) + ) + return logger.debug( - "CEM supports S2 protocol versions: %s. CEM selected to use version %s", - cem_handshake.supported_protocol_versions, - handshake_response.selected_protocol_version, + "%s supports S2 protocol versions: %s", + message.role, + message.supported_protocol_versions, ) + await send_okay + + async def handle_handshake_response_as_rm( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, HandshakeResponse): + logger.error( + "Handler for HandshakeResponse received a message of the wrong type: %s", + type(message), + ) + return + + logger.debug("Received HandshakeResponse %s", message.to_json()) + + logger.debug("CEM selected to use version %s", message.selected_protocol_version) + await send_okay logger.debug("Handshake complete. Sending first ResourceManagerDetails.") - await self.send_msg_and_await_reception_status( + await self.send_msg_and_await_reception_status_async( self.asset_details.to_resource_manager_details(self.control_types) ) - async def handle_select_control_type(self, _: "S2Connection", message: S2Message) -> None: + async def handle_select_control_type_as_rm( + self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None] + ) -> None: + if not isinstance(message, SelectControlType): + logger.error( + "Handler for SelectControlType received a message of the wrong type: %s", + type(message), + ) + return + + await send_okay + logger.debug("CEM selected control type %s. Activating control type.", message.control_type) selected_control_type: S2ControlType = next( @@ -189,42 +318,93 @@ async def handle_select_control_type(self, _: "S2Connection", message: S2Message ) if self._current_control_type is not None: - self._current_control_type.deactivate() + await self._eventloop.run_in_executor(None, self._current_control_type.deactivate, self) self._current_control_type = selected_control_type if self._current_control_type is not None: - self._current_control_type.activate() + await self._eventloop.run_in_executor(None, self._current_control_type.activate, self) self._current_control_type.register_handlers(self._handlers) - async def _receive_next_message(self) -> S2Message: - """Receive next non-ReceptionStatus message. - - :return: The next S2 message which is not a ReceptionStatus. - """ - return await self._received_messages.get() - async def _receive_messages(self) -> None: """Receives all incoming messages in the form of a generator. Will also receive the ReceptionStatus messages but instead of yielding these messages, they are routed to any calls of `send_msg_and_await_reception_status`. """ + logger.info("S2 connection has started to receive messages.") async for message in self.ws: - s2_msg: S2Message = self.s2_parser.parse_as_any_message(message) - - if isinstance(s2_msg, ReceptionStatus): - await self.reception_status_awaiter.receive_reception_status(s2_msg) + try: + s2_msg: S2Message = self.s2_parser.parse_as_any_message(message) + except json.JSONDecodeError: + await self._send_and_forget( + ReceptionStatus( + subject_message_id="00000000-0000-0000-0000-000000000000", + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Not valid json.", + ) + ) + except S2ValidationError as e: + json_msg = json.load(message) + message_id = json_msg.get("message_id") + if message_id: + await self.respond_with_reception_status( + subject_message_id=message_id, + status=ReceptionStatusValues.INVALID_MESSAGE, + diagnostic_label=str(e), + ) + else: + await self.respond_with_reception_status( + subject_message_id="00000000-0000-0000-0000-000000000000", + status=ReceptionStatusValues.INVALID_DATA, + diagnostic_label="Message appears valid json but could not find a message_id field.", + ) else: - await self._received_messages.put(s2_msg) + logger.debug("Received message %s", s2_msg.to_json()) + + if isinstance(s2_msg, ReceptionStatus): + logger.debug( + "Message is a reception status for %s so registering in cache.", + s2_msg.subject_message_id, + ) + await self.reception_status_awaiter.receive_reception_status(s2_msg) + else: + await self._received_messages.put(s2_msg) async def _send_and_forget(self, s2_msg: S2Message) -> None: - await self.ws.send(s2_msg.to_json()) + json_msg = s2_msg.to_json() + logger.debug("Sending message %s", json_msg) + await self.ws.send(json_msg) + + async def respond_with_reception_status( + self, subject_message_id: str, status: ReceptionStatusValues, diagnostic_label: str + ) -> None: + logger.debug("Responding to message %s with status %s", subject_message_id, status) + await self._send_and_forget( + ReceptionStatus( + subject_message_id=subject_message_id, + status=status, + diagnostic_label=diagnostic_label, + ) + ) + + def respond_with_reception_status_sync( + self, subject_message_id: str, status: ReceptionStatusValues, diagnostic_label: str + ) -> None: + asyncio.run_coroutine_threadsafe( + self.respond_with_reception_status(subject_message_id, status, diagnostic_label), + self._eventloop, + ).result() - async def send_msg_and_await_reception_status( + async def send_msg_and_await_reception_status_async( self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True ) -> S2Message: await self._send_and_forget(s2_msg) + logger.debug( + "Waiting for ReceptionStatus for %s %s seconds", + s2_msg.message_id, + timeout_reception_status, + ) reception_status = await self.reception_status_awaiter.wait_for_reception_status( s2_msg.message_id, timeout_reception_status ) @@ -234,6 +414,16 @@ async def send_msg_and_await_reception_status( return reception_status + def send_msg_and_await_reception_status_sync( + self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True + ) -> S2Message: + asyncio.run_coroutine_threadsafe( + self.send_msg_and_await_reception_status_async( + s2_msg, timeout_reception_status, raise_on_error + ), + self._eventloop, + ).result() + async def _handle_received_messages(self) -> None: while True: msg = await self._received_messages.get() diff --git a/src/s2python/s2_control_type.py b/src/s2python/s2_control_type.py index f792ada..f9a4545 100644 --- a/src/s2python/s2_control_type.py +++ b/src/s2python/s2_control_type.py @@ -3,6 +3,7 @@ from s2python.common import ControlType as ProtocolControlType from s2python.frbc import FRBCInstruction +from s2python.validate_values_mixin import S2Message if typing.TYPE_CHECKING: from s2python.s2_connection import S2Connection, MessageHandlers @@ -13,41 +14,43 @@ class S2ControlType(abc.ABC): def get_protocol_control_type(self) -> ProtocolControlType: ... @abc.abstractmethod - def register_handlers(self, handlers: MessageHandlers) -> None: ... + def register_handlers(self, handlers: "MessageHandlers") -> None: ... @abc.abstractmethod - def activate(self) -> None: ... + def activate(self, conn: "S2Connection") -> None: ... @abc.abstractmethod - def deactivate(self) -> None: ... + def deactivate(self, conn: "S2Connection") -> None: ... class FRBCControlType(S2ControlType): def get_protocol_control_type(self) -> ProtocolControlType: return ProtocolControlType.FILL_RATE_BASED_CONTROL - def register_handlers(self, handlers: MessageHandlers) -> None: + def register_handlers(self, handlers: "MessageHandlers") -> None: handlers.register_handler(FRBCInstruction, self.handle_instruction) @abc.abstractmethod - def handle_instruction(self, conn: S2Connection, msg: FRBCInstruction) -> None: ... + def handle_instruction( + self, conn: "S2Connection", msg: S2Message, send_okay: typing.Callable[[], None] + ) -> None: ... @abc.abstractmethod - def activate(self) -> None: ... + def activate(self, conn: "S2Connection") -> None: ... @abc.abstractmethod - def deactivate(self) -> None: ... + def deactivate(self, conn: "S2Connection") -> None: ... class NoControlControlType(S2ControlType): def get_protocol_control_type(self) -> ProtocolControlType: return ProtocolControlType.NOT_CONTROLABLE - def register_handlers(self, handlers: MessageHandlers) -> None: + def register_handlers(self, handlers: "MessageHandlers") -> None: pass @abc.abstractmethod - def activate(self) -> None: ... + def activate(self, conn: "S2Connection") -> None: ... @abc.abstractmethod - def deactivate(self) -> None: ... + def deactivate(self, conn: "S2Connection") -> None: ... diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index c59cc7f..c240610 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -1,3 +1,4 @@ +import uuid from typing import ( TypeVar, Generic, @@ -15,10 +16,7 @@ ) from typing_extensions import Self -from pydantic import ( # pylint: disable=no-name-in-module - BaseModel, - ValidationError, -) +from pydantic import BaseModel, ValidationError # pylint: disable=no-name-in-module from pydantic.main import IncEx from pydantic.v1.error_wrappers import display_errors # pylint: disable=no-name-in-module @@ -104,7 +102,9 @@ def to_json(self: C) -> str: try: return self.model_dump_json(by_alias=True, exclude_none=True) except (ValidationError, TypeError) as e: - raise S2ValidationError(type(self), self, "Pydantic raised a format validation error.", e) from e + raise S2ValidationError( + type(self), self, "Pydantic raised a format validation error.", e + ) from e def to_dict(self: C) -> Dict: return self.model_dump() diff --git a/src/s2python/version.py b/src/s2python/version.py index 6c5007c..3789fe8 100644 --- a/src/s2python/version.py +++ b/src/s2python/version.py @@ -1 +1,3 @@ VERSION = "0.2.0" + +S2_VERSION = "0.0.2-beta" From eb45bdc39539e36e1b80398961e1eb8b60da7866 Mon Sep 17 00:00:00 2001 From: Sebastiaan la Fleur Date: Thu, 12 Sep 2024 16:37:51 +0200 Subject: [PATCH 4/4] add_conn: Fix all linting, testing and typing issues. --- ci/lint.sh | 2 +- ci/typecheck.sh | 2 +- {src/s2python => examples}/example_frbc_rm.py | 10 +- src/s2python/common/__init__.py | 1 - src/s2python/reception_status_awaiter.py | 14 +- src/s2python/s2_connection.py | 182 +++++++++++------- src/s2python/s2_parser.py | 21 +- src/s2python/validate_values_mixin.py | 92 +-------- tests/unit/reception_status_awaiter_test.py | 182 ++++++------------ tests/unit/s2_connection_test.py | 65 +++++++ 10 files changed, 271 insertions(+), 300 deletions(-) rename {src/s2python => examples}/example_frbc_rm.py (93%) create mode 100644 tests/unit/s2_connection_test.py diff --git a/ci/lint.sh b/ci/lint.sh index 34ef27b..c405891 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -1,4 +1,4 @@ #!/usr/bin/env sh . .venv/bin/activate -pylint src/ tests/unit/ +pylint src/ tests/unit/ examples/ diff --git a/ci/typecheck.sh b/ci/typecheck.sh index 4fa4ffe..706035c 100755 --- a/ci/typecheck.sh +++ b/ci/typecheck.sh @@ -1,4 +1,4 @@ #!/usr/bin/env sh . .venv/bin/activate -mypy --config-file mypy.ini src/ ./tests/unit/ +mypy --config-file mypy.ini src/ ./tests/unit/ examples/ diff --git a/src/s2python/example_frbc_rm.py b/examples/example_frbc_rm.py similarity index 93% rename from src/s2python/example_frbc_rm.py rename to examples/example_frbc_rm.py index 3da66c2..bb05bc8 100644 --- a/src/s2python/example_frbc_rm.py +++ b/examples/example_frbc_rm.py @@ -30,7 +30,7 @@ ) from s2python.s2_connection import S2Connection, AssetDetails from s2python.s2_control_type import FRBCControlType, NoControlControlType - +from s2python.validate_values_mixin import S2Message logger = logging.getLogger("s2python") logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -39,8 +39,12 @@ class MyFRBCControlType(FRBCControlType): def handle_instruction( - self, conn: S2Connection, msg: FRBCInstruction, send_okay: Callable[[], None] + self, conn: S2Connection, msg: S2Message, send_okay: Callable[[], None] ) -> None: + if not isinstance(msg, FRBCInstruction): + raise RuntimeError( + f"Expected an FRBCInstruction but received a message of type {type(msg)}." + ) print(f"I have received the message {msg} from {conn}") def activate(self, conn: S2Connection) -> None: @@ -151,6 +155,8 @@ def deactivate(self, conn: S2Connection) -> None: instruction_processing_delay=Duration.from_milliseconds(20), roles=[Role(role=RoleType.ENERGY_CONSUMER, commodity=Commodity.ELECTRICITY)], currency=Currency.EUR, + provides_forecast=False, + provides_power_measurements=[CommodityQuantity.ELECTRIC_POWER_L1], ), ) diff --git a/src/s2python/common/__init__.py b/src/s2python/common/__init__.py index 4f099ed..806de7e 100644 --- a/src/s2python/common/__init__.py +++ b/src/s2python/common/__init__.py @@ -8,7 +8,6 @@ EnergyManagementRole, SessionRequestType, ControlType, - Currency, RevokableObjects, ) diff --git a/src/s2python/reception_status_awaiter.py b/src/s2python/reception_status_awaiter.py index 090cfd5..5c4bd42 100644 --- a/src/s2python/reception_status_awaiter.py +++ b/src/s2python/reception_status_awaiter.py @@ -1,6 +1,8 @@ """ReceptationStatusAwaiter class which notifies any coroutine waiting for a certain reception status message. -Copied from https://github.com/flexiblepower/s2-analyzer/blob/main/backend/s2_analyzer_backend/reception_status_awaiter.py under Apache2 license on 31-08-2024. +Copied from +https://github.com/flexiblepower/s2-analyzer/blob/main/backend/s2_analyzer_backend/reception_status_awaiter.py under +Apache2 license on 31-08-2024. """ import asyncio @@ -8,14 +10,13 @@ from typing import Dict from s2python.common import ReceptionStatus -from s2python.validate_values_mixin import S2Message class ReceptionStatusAwaiter: - received: Dict[uuid.UUID, S2Message] + received: Dict[uuid.UUID, ReceptionStatus] awaiting: Dict[uuid.UUID, asyncio.Event] - def __init__(self): + def __init__(self) -> None: self.received = {} self.awaiting = {} @@ -32,7 +33,7 @@ async def wait_for_reception_status( self.awaiting[message_id] = received_event await asyncio.wait_for(received_event.wait(), timeout_reception_status) - reception_status = self.received.get(message_id) + reception_status = self.received[message_id] if message_id in self.awaiting: del self.awaiting[message_id] @@ -47,7 +48,8 @@ async def receive_reception_status(self, reception_status: ReceptionStatus) -> N if reception_status.subject_message_id in self.received: raise RuntimeError( - f"ReceptationStatus for message_subject_id {reception_status.subject_message_id} has already been received!" + f"ReceptationStatus for message_subject_id {reception_status.subject_message_id} has already " + f"been received!" ) self.received[reception_status.subject_message_id] = reception_status diff --git a/src/s2python/s2_connection.py b/src/s2python/s2_connection.py index 1e6ef1b..28ac6da 100644 --- a/src/s2python/s2_connection.py +++ b/src/s2python/s2_connection.py @@ -1,5 +1,4 @@ import asyncio -import inspect import json import logging import threading @@ -33,11 +32,14 @@ @dataclass -class AssetDetails: +class AssetDetails: # pylint: disable=too-many-instance-attributes resource_id: str + provides_forecast: bool + provides_power_measurements: List[CommodityQuantity] + instruction_processing_delay: Duration - roles: List[Role] = None + roles: List[Role] currency: Optional[Currency] = None name: Optional[str] = None @@ -60,8 +62,8 @@ def to_resource_manager_details( message_id=uuid.uuid4(), model=self.model, name=self.name, - provides_forecast=True, # TODO - provides_power_measurement_types=[CommodityQuantity.ELECTRIC_POWER_L1], # TODO + provides_forecast=self.provides_forecast, + provides_power_measurement_types=self.provides_power_measurements, resource_id=self.resource_id, roles=self.roles, serial_number=self.serial_number, @@ -74,10 +76,59 @@ def to_resource_manager_details( ] +class SendOkay: + status_is_send: threading.Event + connection: "S2Connection" + subject_message_id: uuid.UUID + + def __init__(self, connection: "S2Connection", subject_message_id: uuid.UUID): + self.status_is_send = threading.Event() + self.connection = connection + self.subject_message_id = subject_message_id + + async def run_async(self) -> None: + self.status_is_send.set() + + await self.connection.respond_with_reception_status( + subject_message_id=str(self.subject_message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + def run_sync(self) -> None: + self.status_is_send.set() + + self.connection.respond_with_reception_status_sync( + subject_message_id=str(self.subject_message_id), + status=ReceptionStatusValues.OK, + diagnostic_label="Processed okay.", + ) + + async def ensure_send_async(self, type_msg: Type[S2Message]) -> None: + if not self.status_is_send.is_set(): + logger.warning( + "Handler for message %s %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type_msg, + self.subject_message_id, + ) + await self.run_async() + + def ensure_send_sync(self, type_msg: Type[S2Message]) -> None: + if not self.status_is_send.is_set(): + logger.warning( + "Handler for message %s %s did not call send_okay / function to send the ReceptionStatus. " + "Sending it now.", + type_msg, + self.subject_message_id, + ) + self.run_sync() + + class MessageHandlers: handlers: Dict[Type[S2Message], S2MessageHandler] - def __init__(self): + def __init__(self) -> None: self.handlers = {} async def handle_message(self, connection: "S2Connection", msg: S2Message) -> None: @@ -87,60 +138,28 @@ async def handle_message(self, connection: "S2Connection", msg: S2Message) -> No :param msg: The S2 message """ handler = self.handlers.get(type(msg)) - if handler: - status_is_send = threading.Event() + if handler is not None: + send_okay = SendOkay(connection, msg.message_id) # type: ignore[attr-defined] try: - if inspect.iscoroutinefunction(handler): - - async def send_okay(): - status_is_send.set() - - await connection.respond_with_reception_status( - subject_message_id=str(msg.message_id), - status=ReceptionStatusValues.OK, - diagnostic_label="Processed okay.", - ) - - await handler(connection, msg, send_okay()) - - if not status_is_send.is_set(): - logger.warning( - "Handler for message %s did not call send_okay / function to send the ReceptionStatus. " - "Sending it now.", - type(msg), - ) - await send_okay() + if asyncio.iscoroutinefunction(handler): + await handler(connection, msg, send_okay.run_async()) # type: ignore[arg-type] + await send_okay.ensure_send_async(type(msg)) else: - eventloop = asyncio.get_event_loop() - def do_message(): - def send_okay(): - status_is_send.set() - - connection.respond_with_reception_status_sync( - subject_message_id=str(msg.message_id), - status=ReceptionStatusValues.OK, - diagnostic_label="Processed okay.", - ) - - handler(connection, msg, send_okay) - - if not status_is_send.is_set(): - logger.warning( - "Handler for message %s did not call send_okay / function to send the ReceptionStatus. " - "Sending it now.", - type(msg), - ) - send_okay() + def do_message() -> None: + handler(connection, msg, send_okay.run_sync) # type: ignore[arg-type] + send_okay.ensure_send_sync(type(msg)) + eventloop = asyncio.get_event_loop() await eventloop.run_in_executor(executor=None, func=do_message) except Exception: - if not status_is_send.is_set(): + if not send_okay.status_is_send.is_set(): await connection.respond_with_reception_status( - subject_message_id=str(msg.message_id), + subject_message_id=str(msg.message_id), # type: ignore[attr-defined] status=ReceptionStatusValues.PERMANENT_ERROR, - diagnostic_label=f"While processing message {msg.message_id} an unrecoverable error occurred.", + diagnostic_label=f"While processing message {msg.message_id} " # type: ignore[attr-defined] + f"an unrecoverable error occurred.", ) raise else: @@ -158,7 +177,7 @@ def register_handler(self, msg_type: Type[S2Message], handler: S2MessageHandler) self.handlers[msg_type] = handler -class S2Connection: +class S2Connection: # pylint: disable=too-many-instance-attributes url: str reception_status_awaiter: ReceptionStatusAwaiter ws: Optional[WSConnection] @@ -174,8 +193,7 @@ class S2Connection: _received_messages: asyncio.Queue _eventloop: asyncio.AbstractEventLoop - _receiver_task: Optional[asyncio.Task] - _handle_messages_task: Optional[asyncio.Task] + _background_tasks: Optional[asyncio.Task] _stop_event: asyncio.Event def __init__( @@ -184,7 +202,7 @@ def __init__( role: EnergyManagementRole, control_types: List[S2ControlType], asset_details: AssetDetails, - ): + ) -> None: self.url = url self.reception_status_awaiter = ReceptionStatusAwaiter() self.s2_parser = S2Parser() @@ -193,8 +211,7 @@ def __init__( self._current_control_type = None self._eventloop = asyncio.new_event_loop() - self._receiver_task = None - self._handle_messages_task = None + self._background_tasks = None self.control_types = control_types self.role = role @@ -214,9 +231,21 @@ def _run_eventloop(self) -> None: self._eventloop.run_until_complete(self._run_as_rm()) def stop(self) -> None: - asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop) # TODO .result() + """Stops the S2 connection. + + Note: Ensure this method is called from a different thread than the thread running the S2 connection. + Otherwise it will block waiting on the coroutine _do_stop to terminate successfully but it can't run + the coroutine. A `RuntimeError` will be raised to prevent the indefinite block. + """ + if threading.current_thread() == self._thread: + raise RuntimeError( + "Do not call stop from the thread running the S2 connection. This results in an " + "infinite block!" + ) - async def _do_stop(self): + asyncio.run_coroutine_threadsafe(self._do_stop(), self._eventloop).result() + + async def _do_stop(self) -> None: logger.info("Will stop the S2 connection.") if self._background_tasks: self._background_tasks.cancel() @@ -226,7 +255,7 @@ async def _do_stop(self): await self.ws.close() await self.ws.wait_closed() - async def _run_as_rm(self): + async def _run_as_rm(self) -> None: logger.debug("Connecting as S2 resource manager.") self._received_messages = asyncio.Queue() await self.connect_ws() @@ -310,11 +339,11 @@ async def handle_select_control_type_as_rm( logger.debug("CEM selected control type %s. Activating control type.", message.control_type) - selected_control_type: S2ControlType = next( - filter( - lambda c: c.get_protocol_control_type() == message.control_type, self.control_types - ), - None, + control_types_by_protocol_name = { + c.get_protocol_control_type(): c for c in self.control_types + } + selected_control_type: Optional[S2ControlType] = control_types_by_protocol_name.get( + message.control_type ) if self._current_control_type is not None: @@ -332,7 +361,13 @@ async def _receive_messages(self) -> None: Will also receive the ReceptionStatus messages but instead of yielding these messages, they are routed to any calls of `send_msg_and_await_reception_status`. """ + if self.ws is None: + raise RuntimeError( + "Cannot receive messages if websocket connection is not yet established." + ) + logger.info("S2 connection has started to receive messages.") + async for message in self.ws: try: s2_msg: S2Message = self.s2_parser.parse_as_any_message(message) @@ -345,7 +380,7 @@ async def _receive_messages(self) -> None: ) ) except S2ValidationError as e: - json_msg = json.load(message) + json_msg = json.loads(message) message_id = json_msg.get("message_id") if message_id: await self.respond_with_reception_status( @@ -372,6 +407,11 @@ async def _receive_messages(self) -> None: await self._received_messages.put(s2_msg) async def _send_and_forget(self, s2_msg: S2Message) -> None: + if self.ws is None: + raise RuntimeError( + "Cannot send messages if websocket connection is not yet established." + ) + json_msg = s2_msg.to_json() logger.debug("Sending message %s", json_msg) await self.ws.send(json_msg) @@ -398,15 +438,15 @@ def respond_with_reception_status_sync( async def send_msg_and_await_reception_status_async( self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True - ) -> S2Message: + ) -> ReceptionStatus: await self._send_and_forget(s2_msg) logger.debug( "Waiting for ReceptionStatus for %s %s seconds", - s2_msg.message_id, + s2_msg.message_id, # type: ignore[attr-defined] timeout_reception_status, ) reception_status = await self.reception_status_awaiter.wait_for_reception_status( - s2_msg.message_id, timeout_reception_status + s2_msg.message_id, timeout_reception_status # type: ignore[attr-defined] ) if reception_status.status != ReceptionStatusValues.OK and raise_on_error: @@ -416,8 +456,8 @@ async def send_msg_and_await_reception_status_async( def send_msg_and_await_reception_status_sync( self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True - ) -> S2Message: - asyncio.run_coroutine_threadsafe( + ) -> ReceptionStatus: + return asyncio.run_coroutine_threadsafe( self.send_msg_and_await_reception_status_async( s2_msg, timeout_reception_status, raise_on_error ), diff --git a/src/s2python/s2_parser.py b/src/s2python/s2_parser.py index 99433b6..906a286 100644 --- a/src/s2python/s2_parser.py +++ b/src/s2python/s2_parser.py @@ -24,18 +24,18 @@ FRBCTimerStatus, FRBCUsageForecast, ) -from s2python.validate_values_mixin import SupportsValidation +from s2python.validate_values_mixin import S2Message from s2python.s2_validation_error import S2ValidationError LOGGER = logging.getLogger(__name__) S2MessageType = str -M = TypeVar("M", bound=SupportsValidation) +M = TypeVar("M", bound=S2Message) # May be generated with development_utilities/generate_s2_message_type_to_class.py -TYPE_TO_MESSAGE_CLASS: Dict[str, Type[SupportsValidation]] = { +TYPE_TO_MESSAGE_CLASS: Dict[str, Type[S2Message]] = { "FRBC.ActuatorStatus": FRBCActuatorStatus, "FRBC.FillLevelTargetProfile": FRBCFillLevelTargetProfile, "FRBC.Instruction": FRBCInstruction, @@ -59,13 +59,13 @@ class S2Parser: @staticmethod - def _parse_json_if_required(unparsed_message: Union[dict, str]) -> dict: - if isinstance(unparsed_message, str): + def _parse_json_if_required(unparsed_message: Union[dict, str, bytes]) -> dict: + if isinstance(unparsed_message, (str, bytes)): return json.loads(unparsed_message) return unparsed_message @staticmethod - def parse_as_any_message(unparsed_message: Union[dict, str]) -> SupportsValidation: + def parse_as_any_message(unparsed_message: Union[dict, str, bytes]) -> S2Message: """Parse the message as any S2 python message regardless of message type. :param unparsed_message: The message as a JSON-formatted string or as a json-parsed dictionary. @@ -77,13 +77,16 @@ def parse_as_any_message(unparsed_message: Union[dict, str]) -> SupportsValidati if message_type not in TYPE_TO_MESSAGE_CLASS: raise S2ValidationError( - None, message_json, f"Unable to parse {message_type} as an S2 message. Type unknown.", None + None, + message_json, + f"Unable to parse {message_type} as an S2 message. Type unknown.", + None, ) return TYPE_TO_MESSAGE_CLASS[message_type].model_validate(message_json) @staticmethod - def parse_as_message(unparsed_message: Union[dict, str], as_message: Type[M]) -> M: + def parse_as_message(unparsed_message: Union[dict, str, bytes], as_message: Type[M]) -> M: """Parse the message to a specific S2 python message. :param unparsed_message: The message as a JSON-formatted string or as a JSON-parsed dictionary. @@ -95,7 +98,7 @@ def parse_as_message(unparsed_message: Union[dict, str], as_message: Type[M]) -> return as_message.from_dict(message_json) @staticmethod - def parse_message_type(unparsed_message: Union[dict, str]) -> Optional[S2MessageType]: + def parse_message_type(unparsed_message: Union[dict, str, bytes]) -> Optional[S2MessageType]: """Parse only the message type from the unparsed message. This is useful to call before `parse_as_message` to retrieve the message type and allows for strictly-typed diff --git a/src/s2python/validate_values_mixin.py b/src/s2python/validate_values_mixin.py index c240610..7d0d9d6 100644 --- a/src/s2python/validate_values_mixin.py +++ b/src/s2python/validate_values_mixin.py @@ -1,23 +1,6 @@ -import uuid -from typing import ( - TypeVar, - Generic, - Protocol, - Type, - Optional, - Callable, - Any, - Union, - AbstractSet, - Mapping, - List, - Dict, - Literal, -) -from typing_extensions import Self +from typing import TypeVar, Generic, Type, Callable, Any, Union, AbstractSet, Mapping, List, Dict from pydantic import BaseModel, ValidationError # pylint: disable=no-name-in-module -from pydantic.main import IncEx from pydantic.v1.error_wrappers import display_errors # pylint: disable=no-name-in-module from s2python.s2_validation_error import S2ValidationError @@ -29,75 +12,10 @@ MappingIntStrAny = Mapping[IntStr, Any] -class SupportsValidation(Protocol[B_co]): - # ValidateValuesMixin methods - def to_json(self) -> str: ... +C = TypeVar("C", bound="BaseModel") - def to_dict(self) -> Dict: ... - @classmethod - def from_json(cls, json_str: str) -> B_co: ... - - @classmethod - def from_dict(cls, json_dict: Dict) -> B_co: ... - - # Pydantic methods - @classmethod - def model_validate_json( - cls, - json_data: Union[str, bytes, bytearray], - *, - strict: Optional[bool] = None, - context: Optional[Any] = None, - ) -> Self: ... - - @classmethod - def model_validate( - cls, - obj: Any, - *, - strict: Optional[bool] = None, - from_attributes: Optional[bool] = None, - context: Optional[Any] = None, - ) -> Self: ... - - def model_dump( # pylint: disable=too-many-arguments - self, - *, - mode: Union[Literal["json", "python"], str] = "python", - include: IncEx = None, - exclude: IncEx = None, - context: Optional[Any] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> Dict[str, Any]: ... - - def model_dump_json( # pylint: disable=too-many-arguments - self, - *, - indent: Optional[int] = None, - include: IncEx = None, - exclude: IncEx = None, - context: Optional[Any] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> str: ... - - -C = TypeVar("C", bound="SupportsValidation") - - -class S2Message(Generic[C]): +class S2Message(BaseModel, Generic[C]): def to_json(self: C) -> str: try: return self.model_dump_json(by_alias=True, exclude_none=True) @@ -141,9 +59,7 @@ def inner(*args: List[Any], **kwargs: Dict[str, Any]) -> Any: return inner -def catch_and_convert_exceptions( - input_class: Type[SupportsValidation[B_co]], -) -> Type[SupportsValidation[B_co]]: +def catch_and_convert_exceptions(input_class: Type[S2Message[B_co]]) -> Type[S2Message[B_co]]: input_class.__init__ = convert_to_s2exception(input_class.__init__) # type: ignore[method-assign] input_class.__setattr__ = convert_to_s2exception(input_class.__setattr__) # type: ignore[method-assign] input_class.model_validate_json = convert_to_s2exception( # type: ignore[method-assign] diff --git a/tests/unit/reception_status_awaiter_test.py b/tests/unit/reception_status_awaiter_test.py index 6bfd92e..167966d 100644 --- a/tests/unit/reception_status_awaiter_test.py +++ b/tests/unit/reception_status_awaiter_test.py @@ -1,12 +1,21 @@ """Tests for ReceptionStatusAwaiter. -Copied from https://github.com/flexiblepower/s2-analyzer/blob/main/backend/test/s2_analyzer_backend/reception_status_awaiter_test.py under Apache2 license on 31-08-2024. +Copied from +https://github.com/flexiblepower/s2-analyzer/blob/main/backend/test/s2_analyzer_backend/reception_status_awaiter_test.py +under Apache2 license on 31-08-2024. """ import asyncio +import datetime +import uuid from unittest import IsolatedAsyncioTestCase -from unittest.mock import AsyncMock, Mock +from s2python.common import ( + ReceptionStatus, + ReceptionStatusValues, + InstructionStatus, + InstructionStatusUpdate, +) from s2python.reception_status_awaiter import ReceptionStatusAwaiter @@ -14,26 +23,22 @@ class ReceptionStatusAwaiterTest(IsolatedAsyncioTestCase): async def test__wait_for_reception_status__receive_while_waiting(self): # Arrange awaiter = ReceptionStatusAwaiter() - message_id = "1" - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": message_id, - "status": "OK", - } + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) # Act - wait_task = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) + wait_task = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) should_be_waiting_still = not wait_task.done() await awaiter.receive_reception_status(s2_reception_status) await wait_task received_s2_reception_status = wait_task.result() # Assert - expected_s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) self.assertTrue(should_be_waiting_still) self.assertEqual(expected_s2_reception_status, received_s2_reception_status) @@ -41,38 +46,32 @@ async def test__wait_for_reception_status__receive_while_waiting(self): async def test__wait_for_reception_status__already_received(self): # Arrange awaiter = ReceptionStatusAwaiter() - message_id = "1" - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": message_id, - "status": "OK", - } + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) # Act await awaiter.receive_reception_status(s2_reception_status) - received_s2_reception_status = await awaiter.wait_for_reception_status(message_id) + received_s2_reception_status = await awaiter.wait_for_reception_status(message_id, 1.0) # Assert - expected_s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) self.assertEqual(expected_s2_reception_status, received_s2_reception_status) async def test__wait_for_reception_status__multiple_receive_while_waiting(self): # Arrange awaiter = ReceptionStatusAwaiter() - message_id = "1" - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": message_id, - "status": "OK", - } + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) # Act - wait_task_1 = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) - wait_task_2 = asyncio.create_task(awaiter.wait_for_reception_status(message_id)) + wait_task_1 = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) + wait_task_2 = asyncio.create_task(awaiter.wait_for_reception_status(message_id, 1.0)) should_be_waiting_still_1 = not wait_task_1.done() should_be_waiting_still_2 = not wait_task_2.done() await awaiter.receive_reception_status(s2_reception_status) @@ -82,11 +81,9 @@ async def test__wait_for_reception_status__multiple_receive_while_waiting(self): received_s2_reception_status_2 = wait_task_2.result() # Assert - expected_s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + expected_s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) self.assertTrue(should_be_waiting_still_1) self.assertTrue(should_be_waiting_still_2) @@ -96,20 +93,23 @@ async def test__wait_for_reception_status__multiple_receive_while_waiting(self): async def test__receive_reception_status__wrong_message(self): # Arrange awaiter = ReceptionStatusAwaiter() - s2_msg = {"message_type": "NotAReceptionStatus", "subject_message_id": "1", "status": "OK"} + s2_msg = InstructionStatusUpdate( + message_id=uuid.uuid4(), + instruction_id=uuid.uuid4(), + status_type=InstructionStatus.NEW, + timestamp=datetime.datetime.now(datetime.timezone.utc), + ) # Act / Assert with self.assertRaises(RuntimeError): - await awaiter.receive_reception_status(s2_msg) + await awaiter.receive_reception_status(s2_msg) # type: ignore[arg-type] async def test__receive_reception_status__received_duplicate(self): # Arrange awaiter = ReceptionStatusAwaiter() - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + s2_reception_status = ReceptionStatus( + subject_message_id=uuid.uuid4(), status=ReceptionStatusValues.OK + ) # Act / Assert await awaiter.receive_reception_status(s2_reception_status) @@ -119,18 +119,19 @@ async def test__receive_reception_status__received_duplicate(self): async def test__receive_reception_status__receive_no_awaiting(self): # Arrange awaiter = ReceptionStatusAwaiter() - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + message_id = uuid.uuid4() + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) # Act await awaiter.receive_reception_status(s2_reception_status) # Assert expected_received = { - "1": {"message_type": "ReceptionStatus", "subject_message_id": "1", "status": "OK"} + message_id: ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) } self.assertEqual(awaiter.received, expected_received) self.assertEqual(awaiter.awaiting, {}) @@ -139,12 +140,11 @@ async def test__receive_reception_status__receive_with_awaiting(self): # Arrange awaiter = ReceptionStatusAwaiter() awaiting_event = asyncio.Event() - awaiter.awaiting = {"1": awaiting_event} - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } + message_id = uuid.uuid4() + awaiter.awaiting = {message_id: awaiting_event} + s2_reception_status = ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) # Act should_not_be_set = not awaiting_event.is_set() @@ -153,72 +153,12 @@ async def test__receive_reception_status__receive_with_awaiting(self): # Assert expected_received = { - "1": {"message_type": "ReceptionStatus", "subject_message_id": "1", "status": "OK"} + message_id: ReceptionStatus( + subject_message_id=message_id, status=ReceptionStatusValues.OK + ) } self.assertTrue(should_not_be_set) self.assertTrue(should_be_set) self.assertEqual(awaiter.received, expected_received) self.assertEqual(awaiter.awaiting, {}) - - async def test__send_and_await_reception_status__receive_while_waiting(self): - # Arrange - conn = Mock() - awaiter = ReceptionStatusAwaiter() - message_id = "1" - s2_message = { - "message_type": "Handshake", - "message_id": message_id, - "role": "RM", - "supported_protocol_versions": ["1.0"], - } - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": message_id, - "status": "OK", - } - - # Act - wait_task = asyncio.create_task( - awaiter.send_and_await_reception_status(conn, s2_message, True) - ) - should_be_waiting_still = not wait_task.done() - await awaiter.receive_reception_status(s2_reception_status) - await wait_task - received_s2_reception_status = wait_task.result() - - # Assert - expected_s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": "1", - "status": "OK", - } - - self.assertTrue(should_be_waiting_still) - self.assertEqual(expected_s2_reception_status, received_s2_reception_status) - - async def test__send_and_await_reception_status__receive_while_waiting_not_okay(self): - # Arrange - conn = Mock() - awaiter = ReceptionStatusAwaiter() - message_id = "1" - s2_message = { - "message_type": "Handshake", - "message_id": message_id, - "role": "RM", - "supported_protocol_versions": ["1.0"], - } - s2_reception_status = { - "message_type": "ReceptionStatus", - "subject_message_id": message_id, - "status": "INVALID_MESSAGE", - } - - # Act / Assert - wait_task = asyncio.create_task( - awaiter.send_and_await_reception_status(conn, s2_message, True) - ) - await awaiter.receive_reception_status(s2_reception_status) - - with self.assertRaises(RuntimeError): - await wait_task diff --git a/tests/unit/s2_connection_test.py b/tests/unit/s2_connection_test.py new file mode 100644 index 0000000..fcb8b37 --- /dev/null +++ b/tests/unit/s2_connection_test.py @@ -0,0 +1,65 @@ +# import unittest +# +# +# class S2ConnectionTest(unittest.TestCase): +# async def test__send_and_await_reception_status__receive_while_waiting(self): +# # Arrange +# conn = Mock() +# awaiter = ReceptionStatusAwaiter() +# message_id = "1" +# s2_message = { +# "message_type": "Handshake", +# "message_id": message_id, +# "role": "RM", +# "supported_protocol_versions": ["1.0"], +# } +# s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": message_id, +# "status": "OK", +# } +# +# # Act +# wait_task = asyncio.create_task( +# awaiter.send_and_await_reception_status(conn, s2_message, True) +# ) +# should_be_waiting_still = not wait_task.done() +# await awaiter.receive_reception_status(s2_reception_status) +# await wait_task +# received_s2_reception_status = wait_task.result() +# +# # Assert +# expected_s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": "1", +# "status": "OK", +# } +# +# self.assertTrue(should_be_waiting_still) +# self.assertEqual(expected_s2_reception_status, received_s2_reception_status) +# +# async def test__send_and_await_reception_status__receive_while_waiting_not_okay(self): +# # Arrange +# conn = Mock() +# awaiter = ReceptionStatusAwaiter() +# message_id = "1" +# s2_message = { +# "message_type": "Handshake", +# "message_id": message_id, +# "role": "RM", +# "supported_protocol_versions": ["1.0"], +# } +# s2_reception_status = { +# "message_type": "ReceptionStatus", +# "subject_message_id": message_id, +# "status": "INVALID_MESSAGE", +# } +# +# # Act / Assert +# wait_task = asyncio.create_task( +# awaiter.send_and_await_reception_status(conn, s2_message, True) +# ) +# await awaiter.receive_reception_status(s2_reception_status) +# +# with self.assertRaises(RuntimeError): +# await wait_task