diff --git a/examples/base/game.py b/examples/base/game.py index 3558db3..8bb5be6 100644 --- a/examples/base/game.py +++ b/examples/base/game.py @@ -116,4 +116,3 @@ def valid_step(ctx): ctx.visual.simulate() rule_terminate(ctx) - diff --git a/gamms/AgentEngine/agent_engine.py b/gamms/AgentEngine/agent_engine.py index f71dda2..ce767b1 100644 --- a/gamms/AgentEngine/agent_engine.py +++ b/gamms/AgentEngine/agent_engine.py @@ -187,9 +187,6 @@ def orientation(self) -> float: return 0.0 angle_rad = math.atan2(delta_y, delta_x) return math.degrees(angle_rad) % 360 - - - class AgentEngine(IAgentEngine): def __init__(self, ctx: IContext): @@ -203,8 +200,9 @@ def create_agent(self, name, **kwargs): if self.ctx.record.record(): self.ctx.record.write(opCode=OpCodes.AGENT_CREATE, data={"name": name, "kwargs": kwargs}) start_node_id = kwargs.pop('start_node_id') + sensors = kwargs.pop('sensors', []) agent = Agent(self.ctx, name, start_node_id, **kwargs) - for sensor in kwargs['sensors']: + for sensor in sensors: try: agent.register_sensor(sensor, self.ctx.sensor.get_sensor(sensor)) except KeyError: diff --git a/gamms/GraphEngine/visualization_engine.py b/gamms/Recorder/__init__.py similarity index 100% rename from gamms/GraphEngine/visualization_engine.py rename to gamms/Recorder/__init__.py diff --git a/gamms/Recorder/recorder.py b/gamms/Recorder/recorder.py new file mode 100644 index 0000000..331a137 --- /dev/null +++ b/gamms/Recorder/recorder.py @@ -0,0 +1,214 @@ +from typing import Union, BinaryIO, Callable, Dict, TypeVar, Type, Tuple, Iterator +from gamms.typing.recorder import IRecorder, JsonType +from gamms.typing.opcodes import OpCodes, MAGIC_NUMBER, VERSION +from gamms.typing import IContext +import os +import time +import ubjson +import typing +from gamms.Recorder.component import component +from io import IOBase + +_T = TypeVar('_T') + +def _record_switch_case(ctx: IContext, opCode: OpCodes, data: JsonType) -> None: + if opCode == OpCodes.AGENT_CREATE: + print(f"Creating agent {data['name']} at node {data['kwargs']['start_node_id']}") + ctx.agent.create_agent(data["name"], **data["kwargs"]) + elif opCode == OpCodes.AGENT_DELETE: + print(f"Deleting agent {data}") + ctx.agent.delete_agent(data) + elif opCode == OpCodes.SIMULATE: + ctx.visual.simulate() + elif opCode == OpCodes.AGENT_CURRENT_NODE: + print(f"Agent {data['agent_name']} moved to node {data['node_id']}") + ctx.agent.get_agent(data["agent_name"]).current_node_id = data["node_id"] + elif opCode == OpCodes.AGENT_PREV_NODE: + ctx.agent.get_agent(data["agent_name"]).prev_node_id = data["node_id"] + elif opCode == OpCodes.COMPONENT_REGISTER: + cls_key = tuple(data["key"]) + if ctx.record.is_component_registered(cls_key): + print(f"Component {cls_key} already registered.") + else: + print(f"Registering component {cls_key} of type {data['struct']}") + module, name = cls_key + cls_type = type(name, (object,), {}) + cls_type.__module__ = module + struct = {key: eval(value) for key, value in data["struct"].items()} + ctx.record.component(struct=struct)(cls_type) + elif opCode == OpCodes.COMPONENT_CREATE: + print(f"Creating component {data['name']} of type {data['type']}") + cls_key = tuple(data["type"]) + ctx.record._component_registry[cls_key](name=data["name"]) + elif opCode == OpCodes.COMPONENT_UPDATE: + print(f"Updating component {data['name']} with key {data['key']} to value {data['value']}") + obj = ctx.record.get_component(data["name"]) + setattr(obj, data["key"], data["value"]) + elif opCode == OpCodes.TERMINATE: + print("Terminating...") + else: + raise ValueError(f"Invalid opcode {opCode}") + +class Recorder(IRecorder): + def __init__(self, ctx: IContext): + self.ctx = ctx + self.is_recording = False + self.is_replaying = False + self.is_paused = False + self._fp_record = None + self._fp_replay = None + self._time = None + self._components: Dict[str, Type[_T]] = {} + self._component_registry: Dict[Tuple[str, str], Type[_T]] = {} + + def record(self) -> bool: + if not self.is_paused and self.is_recording and not self.ctx.is_terminated(): + return True + else: + return False + + def start(self, path: Union[str, BinaryIO]) -> None: + if self._fp_record is not None: + raise RuntimeError("Recording file is already open. Stop recording before starting a new one.") + + if isinstance(path, str): + # Check if path has extension .ggr + if not path.endswith('.ggr'): + path += '.ggr' + + if os.path.exists(path): + raise FileExistsError(f"File {path} already exists.") + + self._fp_record = open(path, 'wb') + elif isinstance(path, IOBase): + self._fp_record = path + else: + raise TypeError("Path must be a string or a file object.") + self.is_recording = True + self.is_paused = False + + # Add file validity header + self._fp_record.write(MAGIC_NUMBER) + self._fp_record.write(VERSION) + + def stop(self) -> None: + if not self.is_recording: + raise RuntimeError("Recording has not started.") + self.write(OpCodes.TERMINATE, None) + self.is_recording = False + self.is_paused = False + self._fp_record.close() + self._fp_record = None + + def pause(self) -> None: + if not self.is_recording: + print("Warning: Recording has not started.") + elif self.is_paused: + print("Warning: Recording is already paused.") + else: + self.is_paused = True + print("Recording paused.") + + def play(self) -> None: + if not self.is_recording: + print("Warning: Recording has not started.") + elif not self.is_paused: + print("Warning: Recording is already playing.") + else: + self.is_paused = False + print("Recording resumed.") + + def replay(self, path: Union[str, BinaryIO]): + if self._fp_replay is not None: + raise RuntimeError("Replay file is already open. Stop replaying before starting a new one.") + + if isinstance(path, str): + # Check if path has extension .ggr + if not path.endswith('.ggr'): + path += '.ggr' + + if not os.path.exists(path): + raise FileNotFoundError(f"File {path} does not exist.") + + self._fp_replay = open(path, 'rb') + elif isinstance(path, IOBase): + self._fp_replay = path + else: + raise TypeError("Path must be a string or a file object.") + + # Check file validity header + if self._fp_replay.read(4) != MAGIC_NUMBER: + raise ValueError("Invalid file format.") + + _version = self._fp_replay.read(4) + + # Not checking version for now + self.is_replaying = True + + while self.is_replaying: + try: + record = ubjson.load(self._fp_replay) + except Exception as e: + self.is_replaying = False + self._fp_replay.close() + self._fp_replay = None + print(f"Error reading record: {e}") + raise ValueError("Recording ended unexpectedly.") + self._time = record["timestamp"] + opCode = OpCodes(record["opCode"]) + if opCode == OpCodes.TERMINATE: + self.is_replaying = False + _record_switch_case(self.ctx, opCode, record.get("data", None)) + + yield record + + self._fp_replay.close() + self._fp_replay = None + + def time(self): + if self.is_replaying: + return self._time + return time.monotonic_ns() + + def write(self, opCode: OpCodes, data: JsonType) -> None: + if not self.record(): + raise RuntimeError("Cannot write: Not currently recording.") + timestamp = self.time() + if data is None: + ubjson.dump({"timestamp": timestamp, "opCode": opCode.value}, self._fp_record) + else: + ubjson.dump({"timestamp": timestamp, "opCode": opCode.value, "data": data}, self._fp_record) + + + def component(self, struct: Dict[str, Type[_T]]) -> Callable[[Type[_T]], Type[_T]]: + return component(self.ctx, struct) + + def get_component(self, name: str) -> Type[_T]: + if name not in self._components: + raise KeyError(f"Component {name} not found.") + return self._components[name] + + def delete_component(self, name: str) -> None: + if name not in self._components: + raise KeyError(f"Component {name} not found.") + if self.record(): + self.write(OpCodes.COMPONENT_DELETE, {"name": name}) + del self._components[name] + + def component_iter(self) -> Iterator[str]: + return self._components.keys() + + def add_component(self, name: str, obj: Type[_T]) -> None: + if name in self._components: + raise ValueError(f"Component {name} already exists.") + self._components[name] = obj + + def is_component_registered(self, key: Tuple[str, str]) -> bool: + return key in self._component_registry + + def unregister_component(self, key: Tuple[str, str]) -> None: + if key not in self._component_registry: + raise KeyError(f"Component {key} not found.") + if self.record(): + self.write(OpCodes.COMPONENT_UNREGISTER, {"key": key}) + del self._component_registry[key] \ No newline at end of file diff --git a/gamms/SensorEngine/sensor_engine.py b/gamms/SensorEngine/sensor_engine.py index e3a91bd..998ff5a 100644 --- a/gamms/SensorEngine/sensor_engine.py +++ b/gamms/SensorEngine/sensor_engine.py @@ -7,6 +7,7 @@ _T = TypeVar('_T') + class NeighborSensor(ISensor): def __init__(self, ctx, sensor_id, sensor_type, nodes, edges): self.sensor_id = sensor_id @@ -56,6 +57,10 @@ def __init__(self, ctx, sensor_id, sensor_type, nodes, sensor_range: float, fov: self.node_ids = list(self.nodes.keys()) self._positions = np.array([[self.nodes[nid].x, self.nodes[nid].y] for nid in self.node_ids]) self._owner = None + + @property + def type(self) -> SensorType: + return self._type @property def data(self) -> Dict[str, Any]: @@ -133,6 +138,10 @@ def __init__( self.orientation = orientation self._owner = owner self._data = {} + + @property + def type(self) -> SensorType: + return self._type @property def data(self) -> Dict[str, Any]: diff --git a/gamms/__init__.py b/gamms/__init__.py index e128282..94887a5 100644 --- a/gamms/__init__.py +++ b/gamms/__init__.py @@ -2,7 +2,7 @@ import gamms.SensorEngine.sensor_engine as sensor import gamms.GraphEngine.graph_engine as graph import gamms.VisualizationEngine as visual -from gamms.recorder import Recorder +from gamms.Recorder.recorder import Recorder from gamms.context import Context from enum import Enum diff --git a/gamms/typing/opcodes.py b/gamms/typing/opcodes.py index 3b56e8b..576aaa7 100644 --- a/gamms/typing/opcodes.py +++ b/gamms/typing/opcodes.py @@ -7,6 +7,11 @@ class OpCodes(Enum): AGENT_DELETE = 0x01000001 AGENT_CURRENT_NODE = 0x01100000 AGENT_PREV_NODE = 0x01100001 + COMPONENT_REGISTER = 0x02000000 + COMPONENT_CREATE = 0x02000001 + COMPONENT_UPDATE = 0x02000002 + COMPONENT_DELETE = 0x02000003 + COMPONENT_UNREGISTER = 0x02000004 MAGIC_NUMBER = 0x4D4D4752.to_bytes(4, 'big') VERSION = 0x00000001.to_bytes(4, 'big') \ No newline at end of file diff --git a/gamms/typing/recorder.py b/gamms/typing/recorder.py index 87d4aa5..a651153 100644 --- a/gamms/typing/recorder.py +++ b/gamms/typing/recorder.py @@ -1,7 +1,8 @@ -from typing import List, Union, Iterator, Dict +from typing import List, Union, Iterator, Dict, Type, TypeVar, Callable, BinaryIO, Tuple from abc import ABC, abstractmethod JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]] +_T = TypeVar('_T') class IRecorder(ABC): @abstractmethod @@ -30,13 +31,13 @@ def pause(self) -> None: """ pass @abstractmethod - def play(self, path: str) -> None: + def play(self, path: Union[str, BinaryIO]) -> None: """ Resume recording if paused. If not started or stopped, give warning. """ pass @abstractmethod - def replay(self, path: str) -> Iterator: + def replay(self, path: Union[str, BinaryIO]) -> Iterator: """ Checks validity of the file and output an iterator. """ @@ -52,4 +53,58 @@ def write(self, opCode, data) -> None: """ Write to record buffer if recording. If not recording raise error as it should not happen. """ + pass + + @abstractmethod + def component(self, struct: Dict[str, Type[_T]]) -> Callable[[Type[_T]], Type[_T]]: + """ + Decorator to add a component to the recorder. + """ + pass + + @abstractmethod + def get_component(self, name: str) -> Type[_T]: + """ + Get the component from the name. + Raise key error if not found. + """ + pass + + @abstractmethod + def delete_component(self, name: str) -> None: + """ + Delete the component from the name. + Raise key error if not found. + """ + pass + + @abstractmethod + def component_iter(self) -> Iterator[str]: + """ + Iterator for the component names. + """ + pass + + @abstractmethod + def add_component(self, name: str, obj: Type[_T]) -> None: + """ + Add a component to the recorder. + Raise value error if already exists. + """ + pass + + @abstractmethod + def is_component_registered(self, key: Tuple[str, str]) -> bool: + """ + Check if the component is registered. + Key is (module_name, qualname) + """ + pass + + @abstractmethod + def unregister_component(self, key: Tuple[str, str]) -> None: + """ + Unregister the component. + Key is (module_name, qualname) + """ pass \ No newline at end of file diff --git a/tests/recorder_test.py b/tests/recorder_test.py new file mode 100644 index 0000000..5656757 --- /dev/null +++ b/tests/recorder_test.py @@ -0,0 +1,104 @@ +import unittest +from gamms.typing.opcodes import OpCodes +import gamms +import io + +class RecorderTest(unittest.TestCase): + def setUp(self): + self.ctx = gamms.create_context(vis_engine=gamms.visual.Engine.NO_VIS) + # Manually create a grid graph + for i in range(25): + self.ctx.graph.graph.add_node({'id': i, 'x': i % 5, 'y': i // 5}) + + for i in range(25): + for j in range(25): + if i == j + 1 or i == j - 1 or i == j + 5 or i == j - 5: + self.ctx.graph.graph.add_edge( + {'id': i * 25 + j, 'source': i, 'target': j, 'length': 1} + ) + + # Create in memory file for recording + self.record_fp = io.BytesIO() + self.ctx.record.start(self.record_fp) + + # Create agent at node 0 + self.ctx.agent.create_agent('agent_0', start_node_id=0) + # Create agent at node 24 + self.ctx.agent.create_agent('agent_1', start_node_id=24) + + def test_record(self): + self.assertEqual(self.ctx.record.record(), True) + # Create a recorded component + @self.ctx.record.component(struct={'x': int, 'y': int}) + class TestComponent: + def __init__(self): + self.x = 0 + self.y = 0 + # Create a component + comp = TestComponent(name='test') + # Check if the component values are correct + self.assertEqual(comp.x, 0) + self.assertEqual(comp.y, 0) + self.assertEqual(comp.name, 'test') + comp.x = 1 + comp.y = 2 + + # Check if the component values are correct + self.assertEqual(comp.x, 1) + self.assertEqual(comp.y, 2) + + # Move agent_0 to node 1 + self.ctx.agent.get_agent('agent_0').current_node_id = 1 + # Move agent_1 to node 23 + self.ctx.agent.get_agent('agent_1').current_node_id = 23 + # Simulate + self.ctx.visual.simulate() + # Check if the agents are at the correct nodes + self.assertEqual(self.ctx.agent.get_agent('agent_0').current_node_id, 1) + self.assertEqual(self.ctx.agent.get_agent('agent_1').current_node_id, 23) + + # Copy the recording to a new file + self.record_fp.seek(0) + self.fp_replay = io.BytesIO(self.record_fp.read()) + self.ctx.record.stop() + + # Remove agent_0 and agent_1 + self.ctx.agent.delete_agent('agent_0') + self.ctx.agent.delete_agent('agent_1') + # Check if the agents are removed + self.assertRaises(ValueError, self.ctx.agent.get_agent, 'agent_0') + self.assertRaises(ValueError, self.ctx.agent.get_agent, 'agent_1') + + # Remove the component + self.ctx.record.delete_component('test') + cls_key = (TestComponent.__module__, TestComponent.__qualname__) + self.ctx.record.unregister_component(cls_key) + # Check if the component is removed + self.assertEqual(self.ctx.record.is_component_registered(cls_key), False) + + del TestComponent + + # Replay the recording + try: + for _ in self.ctx.record.replay(self.fp_replay): + pass + except ValueError: + pass + + # Check if the agents are at the correct nodes + self.assertEqual(self.ctx.agent.get_agent('agent_0').current_node_id, 1) + self.assertEqual(self.ctx.agent.get_agent('agent_1').current_node_id, 23) + + # Check if the component is registered + self.assertEqual(self.ctx.record.is_component_registered(cls_key), True) + # Check if the component values are correct + comp = self.ctx.record.get_component('test') + self.assertEqual(comp.x, 1) + self.assertEqual(comp.y, 2) + + + def tearDown(self): + self.ctx.terminate() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file