Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/upstage_des/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from simpy import Environment as SimpyEnv
from simpy import Event as SimEvent
from simpy.core import SimTime

from upstage_des.geography import INTERSECTION_LOCATION_CALLABLE, EarthProtocol
from upstage_des.units.convert import STANDARD_TIMES, TIME_ALTERNATES, unit_convert
Expand Down Expand Up @@ -167,7 +168,7 @@ class RulesError(UpstageError):
"""Raised by the user when a simulation rule is violated."""


class MockEnvironment:
class MockEnvironment(SimpyEnv):
"""A fake environment that holds the ``now`` property and all-caps attributes."""

def __init__(self, now: float):
Expand All @@ -176,7 +177,16 @@ def __init__(self, now: float):
Args:
now (float): The time the environment is at.
"""
self.now = now
super().__init__(initial_time=now)

@property
def now(self) -> SimTime:
"""The current simulation time."""
return self._now

@now.setter
def now(self, value: SimTime) -> None:
self._now = value

@classmethod
def mock(cls, env: Union[SimpyEnv, "MockEnvironment"]) -> "MockEnvironment":
Expand All @@ -196,11 +206,11 @@ def mock(cls, env: Union[SimpyEnv, "MockEnvironment"]) -> "MockEnvironment":
return mock_env

@classmethod
def run(cls, until: float | int) -> None:
def run(cls, until: SimTime | SimEvent | None = None) -> Any | None:
"""Method stub for playing nice with rehearsal.

Args:
until (float | int): Placeholder
until (SimTime | SimEvent | None): Placeholder
"""
raise UpstageError("You tried to use `run` on a mock environment")

Expand Down Expand Up @@ -458,7 +468,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._new_env: MockEnvironment | None = None
super().__init__(*args, **kwargs)

@property # type: ignore [override]
@property
def env(self) -> SimpyEnv | MockEnvironment:
"""Get the relevant environment.

Expand Down
8 changes: 4 additions & 4 deletions src/upstage_des/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from simpy.resources.resource import Release, Request
from simpy.resources.store import StoreGet, StorePut

from .base import SimulationError, UpstageBase, UpstageError
from .base import MockEnvironment, SimulationError, UpstageBase, UpstageError
from .constants import PLANNING_FACTOR_OBJECT
from .units import unit_convert

Expand Down Expand Up @@ -417,7 +417,7 @@ def as_event(self) -> SIM.Event:
SIM.Event: typically an Any or All
"""
sub_events = [self._make_event(event) for event in self.events]
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._simpy_event = self.simpy_equivalent(self.env, sub_events)
return self._simpy_event

Expand Down Expand Up @@ -792,7 +792,7 @@ def __init__(
# yielded on
self._payload: dict[str, Any] = {}
self._auto_reset = auto_reset
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._event = SIM.Event(self.env)

def calculate_time_to_complete(self) -> float:
Expand Down Expand Up @@ -847,7 +847,7 @@ def get_payload(self) -> dict[str, tyAny]:

def reset(self) -> None:
"""Reset the event to allow it to be held again."""
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._event = SIM.Event(self.env)

def cancel(self) -> None:
Expand Down
67 changes: 60 additions & 7 deletions src/upstage_des/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
from collections.abc import Callable
from copy import deepcopy
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, cast, runtime_checkable
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
TypeVar,
cast,
get_args,
get_origin,
runtime_checkable,
)

from simpy import Container, Store

Expand Down Expand Up @@ -152,12 +162,12 @@ def __init__(
else:
self._recording_functions.append((thing, name))

self._types: tuple[type, ...]
self._types: tuple[type, ...] | None

if isinstance(valid_types, type):
self._types = (valid_types,)
elif valid_types is None:
self._types = tuple()
self._types = None
else:
self._types = valid_types
self.IGNORE_LOCK: bool = False
Expand Down Expand Up @@ -218,6 +228,37 @@ def _broadcast_change(self, instance: "Actor", name: str, value: ST) -> None:
if instance._state_listener is not None:
instance._state_listener.send_change(name, value)

def _infer_state(self, instance: "Actor") -> tuple[type, ...]:
"""Infer types for the state.

This should allow isinstance(value, self._infer_state(instance))

Args:
instance (Actor): The actor the state is attached to.

Returns:
tuple[Any,...]: The state type
"""
state_class = instance._state_defs[self.name]
if hasattr(state_class, "__orig_class__"):
args = get_args(state_class.__orig_class__)
return args
return tuple()

def _test_state(self, value: Any, typing: tuple[type, ...]) -> bool:
"""Check for some basic type matches."""
correct = False
for _type in typing:
try:
correct |= isinstance(value, _type)
except TypeError as e:
if "parameterized generic" not in str(e):
raise e
origin = get_origin(_type)
if origin is not None:
correct |= isinstance(value, origin)
return correct

# NOTE: A dictionary as a descriptor doesn't work well,
# because all the operations seem to happen *after* the get
# NOTE: Lists also have the same issue that
Expand All @@ -236,7 +277,14 @@ def __set__(self, instance: "Actor", value: ST) -> None:
f"to value of {old_value}. It cannot be changed once set!"
)

if self._types and not isinstance(value, self._types):
# This is unreliable, likely until PEP 718 is in.
# Typing and valid_types is not required, though, so we can skip by it if it's None
if self._types is None:
self._types = self._infer_state(instance)
if self._types == (Any,):
self._types = ()

if self._types and not self._test_state(value, self._types):
raise TypeError(f"{value} is of type {type(value)} not of type {self._types}")

instance.__dict__[self.name] = value
Expand Down Expand Up @@ -957,8 +1005,6 @@ def __init__(
for v in valid_types:
if not isinstance(v, type) or not issubclass(v, Store | Container):
raise UpstageError(f"Bad valid type for {self}: {v}")
else:
valid_types = (Store, Container)

if default is not None and (
not isinstance(default, type) or not issubclass(default, Store | Container)
Expand Down Expand Up @@ -987,6 +1033,13 @@ def __set__(self, instance: "Actor", value: dict | Any) -> None:
"It cannot be changed once set!"
)

# This is unreliable, likely until PEP 718 is in.
# Typing and valid_types is not required, though, so we can skip it if it's None
if self._types is None:
self._types = self._infer_state(instance)
if self._types == (Any,):
self._types = (Store, Container)

if not isinstance(value, dict):
# we've been passed an actual resource, so save it and leave
if not isinstance(value, self._types):
Expand Down Expand Up @@ -1061,7 +1114,7 @@ def _make_clone(self, instance: "Actor", copy: T) -> T:
"""
base_class = type(copy)
memory: dict[str, Any] = instance.__dict__[f"_memory_for_{self.name}"]
new = base_class(instance.env, **memory) # type: ignore [arg-type]
new = base_class(instance.env, **memory)
if isinstance(copy, Store) and isinstance(new, Store):
new.items = list(copy.items)
if isinstance(copy, Container) and isinstance(new, Container):
Expand Down
3 changes: 1 addition & 2 deletions src/upstage_des/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn

from simpy import Environment as SimpyEnv
from simpy import Event as SimpyEvent
from simpy import Interrupt, Process

Expand Down Expand Up @@ -615,7 +614,7 @@ def run(self, *, actor: "Actor") -> Generator[SimpyEvent, None, None]:
Generator[SimpyEvent, None, None]: Generator for SimPy event queue.
"""
self.make_decision(actor=actor)
assert isinstance(self.env, SimpyEnv)
assert not isinstance(self.env, MockEnvironment)
yield self.env.timeout(0.0)

def run_skip(self, *, actor: "Actor") -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/upstage_des/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from upstage_des.base import (
STAGE_CONTEXT_VAR,
EnvironmentContext,
MockEnvironment,
NamedUpstageEntity,
UpstageBase,
UpstageError,
Expand All @@ -24,6 +25,7 @@
def test_context() -> None:
with EnvironmentContext() as env:
assert isinstance(env, SIM.Environment)
assert not isinstance(env, MockEnvironment)
env.run(until=3)
assert env.now == 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def test_model() -> None:

plane = Plane(
name="searcher",
speed=2,
fuel=200,
speed=2.0,
fuel=200.0,
fuel_burn=5.0,
location=UP.CartesianLocation(20, 10),
debug_log=True,
Expand Down
7 changes: 4 additions & 3 deletions src/upstage_des/test/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import simpy as SIM
from simpy.events import ConditionValue
from simpy.resources import base
from simpy.resources.container import ContainerGet, ContainerPut
from simpy.resources.store import StoreGet, StorePut
Expand Down Expand Up @@ -357,18 +358,18 @@ def a_process() -> SIMPY_GEN:
yield env.timeout(2)

class Thing(Actor):
result = State[dict]()
result = State[ConditionValue]()
events = State[list]()

class TheTask(Task):
def task(self, *, actor: Thing) -> TASK_GEN:
wait = Wait(3.0)
proc = env.process(a_process())
res = yield Any(wait, proc)
res: ConditionValue = yield Any(wait, proc)
actor.events = [wait, proc]
actor.result = res

t = Thing(name="Thing", result=None, events=None)
t = Thing(name="Thing", result=ConditionValue(), events=[])
task = TheTask()
task.run(actor=t)
with pytest.warns(UserWarning):
Expand Down
18 changes: 9 additions & 9 deletions src/upstage_des/test/test_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_motion_coordination_cli() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, detect=True)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, detect=True)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_background_motion() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, detect=True)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, detect=True)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -502,7 +502,7 @@ def test_background_rehearse() -> None:

# This is the key change relative to the above
flyer_start = UP.CartesianLocation(*[8, 8, 2])
flyer = RealMover(name="A Mover", loc=flyer_start, speed=1, detect=True)
flyer = RealMover(name="A Mover", loc=flyer_start, speed=1.0, detect=True)
waypoints = [
flyer_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_interrupt_clean() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, detect=True, debug_log=True)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, detect=True, debug_log=True)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_undetectable_cli() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, debug_log=True, detect=True)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, debug_log=True, detect=True)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -612,7 +612,7 @@ def test_redetectable() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, debug_log=True, detect=False)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, debug_log=True, detect=False)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand All @@ -639,7 +639,7 @@ def test_undetectable_after() -> None:
sensor = DummySensor(env, loc, radius=10.0)

mover_start = UP.CartesianLocation(*[8, 8, 2])
mover = RealMover(name="A Mover", loc=mover_start, speed=1, debug_log=True, detect=True)
mover = RealMover(name="A Mover", loc=mover_start, speed=1.0, debug_log=True, detect=True)
waypoints = [
mover_start,
UP.CartesianLocation(*[-8, 8, 2]),
Expand Down Expand Up @@ -740,7 +740,7 @@ def test_motion_coordination_gi() -> None:

t = 2
geo_mover_start = UP.GeodeticLocation(*[t, t, 4000])
geo_mover = RealGeodeticMover(name="Mover", loc=geo_mover_start, speed=1, detect=True)
geo_mover = RealGeodeticMover(name="Mover", loc=geo_mover_start, speed=1.0, detect=True)

waypoints = [
geo_mover_start,
Expand Down Expand Up @@ -850,7 +850,7 @@ def test_motion_coordination_agi() -> None:

t = 2
geo_mover_start = UP.GeodeticLocation(*[t, t, 4000])
geo_mover = RealGeodeticMover(name="Mover", loc=geo_mover_start, speed=1, detect=True)
geo_mover = RealGeodeticMover(name="Mover", loc=geo_mover_start, speed=1.0, detect=True)

waypoints = [
geo_mover_start,
Expand Down
2 changes: 1 addition & 1 deletion src/upstage_des/test/test_nucleus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class Dummy(UP.Actor):
number = UP.State[float]()
number = UP.State[float | int]()
results = UP.State[int](default=0)


Expand Down
2 changes: 1 addition & 1 deletion src/upstage_des/test/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_windowed_get() -> None:

env.process(_placer(env, act))
env.run(until=5.0)
act.timeout = 1
act.timeout = 1.0
proc.interrupt(cause="restart")
env.run()

Expand Down
Loading
Loading