Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
build-backend = "setuptools.build_meta"

[project]
name = "s2-python"
description = "S2 Protocol Python Implementation"
version = "0.5.0"

[project.optional-dependencies]
ws = ["websockets"]
fastapi = ["fastapi"]
flask = ["Flask"]
Comment on lines +5 to +13
Copy link
Collaborator

@Flix6x Flix6x Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of this was already defined in setup.cfg. Let's keep it there (see this related PR). you were correct in that pyproject.toml is the way to go rather than setup.cfg.

73 changes: 73 additions & 0 deletions src/s2python/authorization/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from typing import Any, Dict


class AbstractConnectionClient(ABC):
"""Abstract class for handling the /requestConnection endpoint."""

def request_connection(self) -> Any:
"""Orchestrate the connection request flow: build → execute → handle."""
request_data = self.build_connection_request()
response_data = self.execute_connection_request(request_data)
return self.handle_connection_response(response_data)

@abstractmethod
def build_connection_request(self) -> Dict:
"""
Build the payload for the ConnectionRequest schema.
Returns a dictionary with keys: s2ClientNodeId, supportedProtocols.
"""
pass

@abstractmethod
def execute_connection_request(self, request_data: Dict) -> Dict:
"""
Execute the POST request to /requestConnection.
Implementations should send the request_data to the endpoint
and return the JSON response as a dictionary.
"""
pass

@abstractmethod
def handle_connection_response(self, response_data: Dict) -> Any:
"""
Process the ConnectionDetails response (e.g., extract challenge and connection URI).
The response_data contains keys: selectedProtocol, challenge, connectionUri.
"""
pass


class AbstractPairingClient(ABC):
"""Abstract class for handling the /requestPairing endpoint."""

def request_pairing(self) -> Any:
"""Orchestrate the pairing request flow: build → execute → handle."""
request_data = self.build_pairing_request()
response_data = self.execute_pairing_request(request_data)
return self.handle_pairing_response(response_data)

@abstractmethod
def build_pairing_request(self) -> Dict:
"""
Build the payload for the PairingRequest schema.
Returns a dictionary with keys: token, publicKey, s2ClientNodeId,
s2ClientNodeDescription, supportedProtocols.
"""
pass

@abstractmethod
def execute_pairing_request(self, request_data: Dict) -> Dict:
"""
Execute the POST request to /requestPairing.
Implementations should send the request_data to the endpoint
and return the JSON response as a dictionary.
"""
pass

@abstractmethod
def handle_pairing_response(self, response_data: Dict) -> Any:
"""
Process the PairingResponse (e.g., extract server details).
The response_data contains keys: s2ServerNodeId, serverNodeDescription, requestConnectionUri.
"""
pass
Empty file.
Empty file.
Empty file.
Empty file.
173 changes: 173 additions & 0 deletions src/s2python/generated/gen_s2_pairing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""
Generated classes based on s2-over-ip-pairing.yaml OpenAPI schema.
This file is auto-generated and should not be modified directly.
"""

import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from typing import List, Optional


class Protocols(str, Enum):
"""Supported protocol types."""

WebSocketSecure = "WebSocketSecure"


class S2Role(str, Enum):
"""Roles in the S2 protocol."""

CEM = "CEM"
RM = "RM"


class Deployment(str, Enum):
"""Deployment types."""

WAN = "WAN"
LAN = "LAN"


@dataclass
class S2NodeDescription:
"""Description of an S2 node."""

brand: Optional[str] = None
logoUri: Optional[str] = None
type: Optional[str] = None
modelName: Optional[str] = None
userDefinedName: Optional[str] = None
role: Optional[S2Role] = None
deployment: Optional[Deployment] = None


class PairingToken(str):
"""A token used for pairing.

Must match pattern: ^[0-9a-zA-Z]{32}$
"""

def __new__(cls, content: str):
import re

if not re.match(r"^[0-9a-zA-Z]{32}$", content):
raise ValueError("PairingToken must be 32 alphanumeric characters")
return super().__new__(cls, content)


@dataclass
class PairingInfo:
"""Information about a pairing."""

pairingUri: Optional[str] = None
token: Optional[PairingToken] = None
validUntil: Optional[datetime] = None


@dataclass
class PairingRequest:
"""Request to initiate pairing."""

token: Optional[PairingToken] = None
publicKey: Optional[bytes] = None
s2ClientNodeId: Optional[uuid.UUID] = None
s2ClientNodeDescription: Optional[S2NodeDescription] = None
supportedProtocols: Optional[List[Protocols]] = None


@dataclass
class PairingResponse:
"""Response to a pairing request."""

s2ServerNodeId: Optional[uuid.UUID] = None
serverNodeDescription: Optional[S2NodeDescription] = None
requestConnectionUri: Optional[str] = None


@dataclass
class ConnectionRequest:
"""Request to establish a connection."""

s2ClientNodeId: Optional[uuid.UUID] = None
supportedProtocols: Optional[List[Protocols]] = None


@dataclass
class ConnectionDetails:
"""Details for establishing a connection."""

selectedProtocol: Optional[Protocols] = None
challenge: Optional[bytes] = None
connectionUri: Optional[str] = None


# Serialization/Deserialization functions


def _is_dataclass_instance(obj):
"""Check if an object is a dataclass instance."""
from dataclasses import is_dataclass

return is_dataclass(obj) and not isinstance(obj, type)


def to_dict(obj):
"""Convert a dataclass instance to a dictionary."""
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, bytes):
import base64

return base64.b64encode(obj).decode("ascii")
elif isinstance(obj, Enum):
return obj.value
elif isinstance(obj, list):
return [to_dict(item) for item in obj]
elif _is_dataclass_instance(obj):
result = {}
for field in obj.__dataclass_fields__:
value = getattr(obj, field)
if value is not None:
result[field] = to_dict(value)
return result
else:
return obj


def from_dict(cls, data):
"""Create a dataclass instance from a dictionary."""
if data is None:
return None

if cls is datetime:
return datetime.fromisoformat(data)
elif cls is uuid.UUID:
return uuid.UUID(data)
elif cls is bytes:
import base64

return base64.b64decode(data.encode("ascii"))
elif issubclass(cls, Enum):
return cls(data)
elif issubclass(cls, PairingToken):
return PairingToken(data)
elif hasattr(cls, "__dataclass_fields__"):
fieldtypes = cls.__annotations__
instance_data = {}

for field, field_type in fieldtypes.items():
if field in data and data[field] is not None:
# Handle List[Type] annotations
if hasattr(field_type, "__origin__") and field_type.__origin__ is list:
item_type = field_type.__args__[0]
instance_data[field] = [from_dict(item_type, item) for item in data[field]]
else:
instance_data[field] = from_dict(field_type, data[field])

return cls(**instance_data)
else:
return data
55 changes: 17 additions & 38 deletions src/s2python/s2_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from dataclasses import dataclass
from typing import Any, Optional, List, Type, Dict, Callable, Awaitable, Union

import websockets
from websockets.asyncio.client import ClientConnection as WSConnection, connect as ws_connect
try:
import websockets
from websockets.asyncio.client import ClientConnection as WSConnection, connect as ws_connect
except ImportError:
raise ImportError("You need to run 'pip install s2-python[ws]' to use this feature.")

from s2python.common import (
ReceptionStatusValues,
Expand Down Expand Up @@ -51,13 +54,9 @@ class AssetDetails: # pylint: disable=too-many-instance-attributes
firmware_version: Optional[str] = None
serial_number: Optional[str] = None

def to_resource_manager_details(
self, control_types: List[S2ControlType]
) -> ResourceManagerDetails:
def to_resource_manager_details(self, control_types: List[S2ControlType]) -> ResourceManagerDetails:
return ResourceManagerDetails(
available_control_types=[
control_type.get_protocol_control_type() for control_type in control_types
],
available_control_types=[control_type.get_protocol_control_type() for control_type in control_types],
currency=self.currency,
firmware_version=self.firmware_version,
instruction_processing_delay=self.instruction_processing_delay,
Expand Down Expand Up @@ -298,9 +297,7 @@ async def wait_till_connection_restart() -> None:
self._eventloop.create_task(wait_till_connection_restart()),
]

(done, pending) = await asyncio.wait(
background_tasks, return_when=asyncio.FIRST_COMPLETED
)
(done, pending) = await asyncio.wait(background_tasks, return_when=asyncio.FIRST_COMPLETED)
if self._current_control_type:
self._current_control_type.deactivate(self)
self._current_control_type = None
Expand Down Expand Up @@ -333,31 +330,23 @@ async def _connect_ws(self) -> None:
connection_kwargs["ssl"].verify_mode = ssl.CERT_NONE

if self._bearer_token:
connection_kwargs["additional_headers"] = {
"Authorization": f"Bearer {self._bearer_token}"
}
connection_kwargs["additional_headers"] = {"Authorization": f"Bearer {self._bearer_token}"}

self.ws = await ws_connect(uri=self.url, **connection_kwargs)
except (EOFError, OSError) as e:
logger.info("Could not connect due to: %s", str(e))

async def _connect_as_rm(self) -> None:
await self.send_msg_and_await_reception_status_async(
Handshake(
message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[S2_VERSION]
)
Handshake(message_id=uuid.uuid4(), role=self.role, supported_protocol_versions=[S2_VERSION])
)
logger.debug("Send handshake to CEM. Expecting Handshake and HandshakeResponse from CEM.")

await self._handle_received_messages()

async def handle_handshake(
self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None]
) -> None:
async def handle_handshake(self, _: "S2Connection", message: S2Message, send_okay: Awaitable[None]) -> None:
if not isinstance(message, Handshake):
logger.error(
"Handler for Handshake received a message of the wrong type: %s", type(message)
)
logger.error("Handler for Handshake received a message of the wrong type: %s", type(message))
return

logger.debug(
Expand Down Expand Up @@ -401,12 +390,8 @@ async def handle_select_control_type_as_rm(

logger.debug("CEM selected control type %s. Activating control type.", message.control_type)

control_types_by_protocol_name = {
c.get_protocol_control_type(): c for c in self.control_types
}
selected_control_type: Optional[S2ControlType] = control_types_by_protocol_name.get(
message.control_type
)
control_types_by_protocol_name = {c.get_protocol_control_type(): c for c in self.control_types}
selected_control_type: Optional[S2ControlType] = control_types_by_protocol_name.get(message.control_type)

if self._current_control_type is not None:
await self._eventloop.run_in_executor(None, self._current_control_type.deactivate, self)
Expand All @@ -424,9 +409,7 @@ async def _receive_messages(self) -> None:
to any calls of `send_msg_and_await_reception_status`.
"""
if self.ws is None:
raise RuntimeError(
"Cannot receive messages if websocket connection is not yet established."
)
raise RuntimeError("Cannot receive messages if websocket connection is not yet established.")

logger.info("S2 connection has started to receive messages.")

Expand Down Expand Up @@ -470,9 +453,7 @@ async def _receive_messages(self) -> None:

async def _send_and_forget(self, s2_msg: S2Message) -> None:
if self.ws is None:
raise RuntimeError(
"Cannot send messages if websocket connection is not yet established."
)
raise RuntimeError("Cannot send messages if websocket connection is not yet established.")

json_msg = s2_msg.to_json()
logger.debug("Sending message %s", json_msg)
Expand Down Expand Up @@ -532,9 +513,7 @@ def send_msg_and_await_reception_status_sync(
self, s2_msg: S2Message, timeout_reception_status: float = 5.0, raise_on_error: bool = True
) -> ReceptionStatus:
return asyncio.run_coroutine_threadsafe(
self.send_msg_and_await_reception_status_async(
s2_msg, timeout_reception_status, raise_on_error
),
self.send_msg_and_await_reception_status_async(s2_msg, timeout_reception_status, raise_on_error),
self._eventloop,
).result()

Expand Down
Loading