Skip to content
Merged

Py310 #121

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
---
name: Python tests

on:
Expand All @@ -12,7 +13,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down
14 changes: 6 additions & 8 deletions engibench/constraint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Constraints for parameters of Problem classes."""

from __future__ import annotations

from collections.abc import Callable, Iterable
import dataclasses
from dataclasses import dataclass
Expand Down Expand Up @@ -48,15 +46,15 @@ class Constraint:
criticality: Criticality = Criticality.Error
"""Criticality of a violation of the constraint."""

def category(self, category: Category) -> Constraint:
def category(self, category: Category) -> "Constraint":
"""Return a copy of the constraint which has the specified category added."""
return Constraint(check=self.check, criticality=self.criticality, categories=self.categories | category)

def warning(self) -> Constraint:
def warning(self) -> "Constraint":
"""Return a copy of the constraint with the criticality level set to "warning"."""
return Constraint(check=self.check, criticality=Criticality.Warning, categories=self.categories)

def check_dict(self, parameter_args: dict[str, Any]) -> Violation | None:
def check_dict(self, parameter_args: dict[str, Any]) -> "Violation | None":
"""Check for a violation of the given constraint for the given parameters."""
# We first inspect the arguments of check callback:
sig = inspect.signature(self.check)
Expand All @@ -82,7 +80,7 @@ def check_dict(self, parameter_args: dict[str, Any]) -> Violation | None:
return Violation(self, str(e))
return None

def check_value(self, value: Any) -> Violation | None:
def check_value(self, value: Any) -> "Violation | None":
"""Check for a violation for the given single positional value."""
try:
self.check(value)
Expand Down Expand Up @@ -165,13 +163,13 @@ def __init__(self, violations: list[Violation], n_constraints: int) -> None:
self.violations = violations
self.n_constraints = n_constraints

def by_category(self, category: Category) -> Violations:
def by_category(self, category: Category) -> "Violations":
"""Filter the violations by the category of the constraint causing the violation."""
return Violations(
[violation for violation in self.violations if category in violation.constraint.categories], self.n_constraints
)

def by_criticality(self, criticality: Criticality) -> Violations:
def by_criticality(self, criticality: Criticality) -> "Violations":
"""Filter the violations by criticality."""
return Violations(
[violation for violation in self.violations if violation.constraint.criticality == criticality],
Expand Down
11 changes: 3 additions & 8 deletions engibench/core.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
"""Core API for Problem and other base classes."""

from __future__ import annotations

from collections.abc import Sequence
import dataclasses
from enum import auto
from enum import Enum
from typing import Any, Generic, TYPE_CHECKING, TypeVar
from typing import Any, Generic, TypeVar

from datasets import Dataset
from datasets import load_dataset
from gymnasium import spaces
import numpy as np
import numpy.typing as npt

from engibench import constraint

if TYPE_CHECKING:
from collections.abc import Sequence

from gymnasium import spaces

DesignType = TypeVar("DesignType")


Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/airfoil/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
+-+-+-+-+-+-+-+-+-+
"""

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
import os
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/beams2d/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
This code has been adapted from the Python implementation by Niels Aage and Villads Egede Johansen: https://github.com/arjendeetman/TopOpt-MMA-Python
"""

from __future__ import annotations

import dataclasses
from typing import Any, overload

Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/beams2d/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""Beams 2D problem."""

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from dataclasses import field
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/heatconduction2d/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
The problem is solved using the dolfin-adjoint software within a Docker container.
"""

from __future__ import annotations

from dataclasses import dataclass
import os
import subprocess
Expand Down
4 changes: 1 addition & 3 deletions engibench/problems/heatconduction3d/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
The problem is solved using the dolfin-adjoint software within a Docker container.
"""

from __future__ import annotations

from dataclasses import dataclass
import os
import subprocess
Expand Down Expand Up @@ -292,7 +290,7 @@ def render(self, design: npt.NDArray, *, open_window: bool = False) -> Any:
] # Side edges

for edge in cube_edges:
ax.plot(*zip(*cube_vertices[list(edge)]), color="red", linewidth=2)
ax.plot(*zip(*cube_vertices[list(edge)], strict=True), color="red", linewidth=2)

if open_window:
plt.show()
Expand Down
3 changes: 2 additions & 1 deletion engibench/problems/photonics2d/dataset_slurm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""

from argparse import ArgumentParser
from collections.abc import Callable
from itertools import product
import os
import pickle
import shutil
import time
from typing import Any, Callable
from typing import Any

import numpy as np

Expand Down
8 changes: 2 additions & 6 deletions engibench/problems/photonics2d/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
Author: Mark Fuge @markfuge
"""

from __future__ import annotations

from dataclasses import dataclass

# Need os import for makedirs for saving plots
import os
import pprint
from typing import Annotated, Any, ClassVar, TYPE_CHECKING
from typing import Annotated, Any, ClassVar

# Importing autograd since the ceviche library uses it for automatic differentiation of the FDFD solver
import autograd.numpy as npa
from autograd.numpy.numpy_boxes import ArrayBox

# Import ArrayBox type for checking
import ceviche
Expand Down Expand Up @@ -47,9 +46,6 @@
from engibench.problems.photonics2d.backend import poly_ramp
from engibench.problems.photonics2d.backend import wavelength_to_frequency

if TYPE_CHECKING:
from autograd.numpy.numpy_boxes import ArrayBox


class Photonics2D(Problem[npt.NDArray]):
r"""Photonic Inverse Design 2D Problem (Wavelength Demultiplexer).
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/power_electronics/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Set up the configuration for the Power Electronics problem."""
# ruff: noqa: N806, N815 # Upper case

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@

# ruff: noqa: N806 # Upper case variables

from __future__ import annotations

from typing import TYPE_CHECKING

import networkx as nx

from engibench.problems.power_electronics.utils.config import Config
from engibench.problems.power_electronics.utils.constants import COLOR_DICT
from engibench.problems.power_electronics.utils.constants import COMPONENTS

if TYPE_CHECKING:
from engibench.problems.power_electronics.utils.config import Config


def parse_topology(config: Config) -> tuple[Config, str, dict[str, list[int]], nx.Graph]:
"""Parse the topology from config.original_netlist_path. It does NOT change config.
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/power_electronics/utils/ngspice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""NgSpice wrapper for cross-platform support."""

from __future__ import annotations

import os
import platform
import re
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/power_electronics/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""Power Electronics problem."""

from __future__ import annotations

import os
from typing import Any, NoReturn

Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/thermoelastic2d/model/fea_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""This module contains the Python implementation of the thermoelastic 2D problem."""

from __future__ import annotations

from math import ceil
from math import hypot
import time
Expand Down
7 changes: 1 addition & 6 deletions engibench/problems/thermoelastic2d/model/mma_subroutine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""This module contains the MMA subroutine used in the thermoelastic2d problem."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from mmapy import mmasub as external_mmasub
import numpy as np

if TYPE_CHECKING:
from numpy.typing import NDArray
from numpy.typing import NDArray

RESIDUAL_MAX_VAL = 0.9
ITERATION_MAX = 500
Expand Down
8 changes: 1 addition & 7 deletions engibench/problems/thermoelastic2d/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
"""Utility functions for the thermoelastic2d problem."""

from __future__ import annotations

from typing import TYPE_CHECKING

from matplotlib import colors
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt

if TYPE_CHECKING:
from matplotlib.figure import Figure


def get_res_bounds(x_res: int, y_res: int) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray]:
"""Generates the indices corresponding to the left, top, right, and bottom elements in the domain.
Expand Down
2 changes: 0 additions & 2 deletions engibench/problems/thermoelastic2d/v0.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Thermo Elastic 2D Problem."""

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
from typing import Annotated, Any, ClassVar
Expand Down
47 changes: 21 additions & 26 deletions engibench/utils/container.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
"""Abstraction over container runtimes."""

from __future__ import annotations

from collections.abc import Sequence
import os
import subprocess
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence


def pull(image: str) -> None:
Expand Down Expand Up @@ -51,26 +46,6 @@ def run(
raise RuntimeError(msg) from e


def runtime() -> type[ContainerRuntime] | None:
"""Determine the container runtime to use according to the environment variable `CONTAINER_RUNTIME`.

If not set, check for availability.

Returns:
Class object of the first available container runtime or the container runtime selected by the
`CONTAINER_RUNTIME` environment variable if set.
"""
runtimes_by_name = {rt.name: rt for rt in RUNTIMES}
rt_name = os.environ.get("CONTAINER_RUNTIME")
rt = runtimes_by_name.get(rt_name) if rt_name is not None else None
if rt is not None:
return rt
for rt in RUNTIMES:
if rt.is_available():
return rt
return None


class ContainerRuntime:
"""Abstraction over container runtimes."""

Expand Down Expand Up @@ -125,6 +100,26 @@ def run(
raise NotImplementedError("Must be implemented by a subclass")


def runtime() -> type[ContainerRuntime] | None:
"""Determine the container runtime to use according to the environment variable `CONTAINER_RUNTIME`.

If not set, check for availability.

Returns:
Class object of the first available container runtime or the container runtime selected by the
`CONTAINER_RUNTIME` environment variable if set.
"""
runtimes_by_name = {rt.name: rt for rt in RUNTIMES}
rt_name = os.environ.get("CONTAINER_RUNTIME")
rt = runtimes_by_name.get(rt_name) if rt_name is not None else None
if rt is not None:
return rt
for rt in RUNTIMES:
if rt.is_available():
return rt
return None


class Docker(ContainerRuntime):
"""Docker 🐋 runtime."""

Expand Down
12 changes: 3 additions & 9 deletions engibench/utils/slurm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Slurm executor for parameter space discovery."""

from __future__ import annotations

from argparse import ArgumentParser
from collections.abc import Callable, Iterable, Sequence
from dataclasses import asdict
from dataclasses import dataclass
from dataclasses import field
Expand All @@ -14,18 +13,13 @@
import subprocess
import sys
import tempfile
from typing import Any, Generic, TYPE_CHECKING, TypeVar
from typing import Any, Generic, TypeVar

from numpy import typing as npt

from engibench.core import OptiStep
from engibench.core import Problem

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence

import numpy.typing as npt


@dataclass
class Args:
Expand Down Expand Up @@ -73,7 +67,7 @@ def serialize(self) -> dict[str, Any]:
}

@classmethod
def deserialize(cls, serialized_job: dict[str, Any]) -> Job:
def deserialize(cls, serialized_job: dict[str, Any]) -> "Job":
"""Deserialize a job object from an other python process."""
design_factory = serialized_job["design_factory"]
return cls(
Expand Down
Loading
Loading