Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ If you use EngiBench in your research, please cite the following paper:
url = {https://openreview.net/forum?id=YowD33Q89V},
urldate = {2025-10-07},
author = {Felten, Florian and Apaza, Gabriel and B\¨aunlich, Gerhard and Diniz, Cashen and Dong, Xuliang and Drake, Arthur and Habibi, Milad and Hoffman, Nathaniel J. and Keeler, Matthew and Massoudi, Soheyl and VanGessel, Francis G. and Fuge, Mark},
booktitle = {Proceedings of the 39th Conference on Neural Information Processing Systems ({NeurIPS} 2025)}
booktitle = {Proceedings of the 39th Conference on Neural Information Processing Systems ({NeurIPS} 2025)},
year = {2025},
}
```
Expand Down
30 changes: 26 additions & 4 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,38 @@ Fork EngiBench and edit the docstring in the problem's Python file. Then, pip in
### Adding a new problem

Ensure the problem is in EngiBench (or your fork). Ensure that its Python file has a properly formatted markdown docstring. Install using `pip install -e .[doc]` and add a markdown file in the [./problems/](problems/) of the repo.
Use Engibench's own `problem` directive in the docs of your new problem:
Use EngiBench's own `problem:table` and `problem:conditions` directives in the docs of your new problem:
``````md
# Your Problem

``` {problem} your_problem
``` {problem:table}
:lead: Chuck Norris @chucknorris
```

...

## Conditions

``` {problem:conditions}
```

...
``````

Here, `your_problem` must match the name of the module where your problem class is defined.
This will automatically include the docstrings of your `Problem` class as well as a table with its metadata. Then complete the [other steps](#other-steps).
**`problem:table`**: This directive extracts metadata from a problem and
inserts a table filled with the metadata.
By default, the directive will try to import the problem `engibench.problems.<problem_id>`, where `<problem_id>` is the filename (without `.md` extension) of the markdown file the directive is used.

Options (optional):
* `:problem_id:` override `<problem_id>`,
* `:lead:` Add a row "Lead" to the table, containing the specified value.
If the value ends with `@username`, a link to `https://github.com/username` will be inserted.

**`problem:conditions`**: This directive lists the conditions extracted from a problem as in the "Conditions" row produced by the `problem:table` directive. The `<problem_id>` is determined the same way as in `problem:table`.

Options:
* `:problem_id:` override `<problem_id>`,
* `:defaults:` include default values in the list of conditions

#### Other steps

Expand Down
230 changes: 208 additions & 22 deletions docs/_ext/problem_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,94 @@
`````
"""

from collections.abc import Iterator, Sequence
from collections.abc import Iterable, Iterator, Sequence
import contextlib
import dataclasses
import importlib.abc
import importlib.machinery
import inspect
import sys
from types import ModuleType
from typing import Any
from typing import Any, ClassVar, get_type_hints, TYPE_CHECKING
import unittest.mock

from docutils import nodes
from docutils.parsers.rst import directives
from sphinx import addnodes
from sphinx.application import Sphinx
from sphinx.domains import Domain
from sphinx.util.docutils import SphinxDirective
from sphinx.util.typing import ExtensionMetadata

if TYPE_CHECKING:
from sphinx.builders import Builder
from sphinx.environment import BuildEnvironment

MODULE_WHITELIST = frozenset(["engibench"])
MODULE_EXTRA_MEMBERS = {"networkx": ["Graph"], "gymnasium": ["spaces"]}


def setup(app: Sphinx) -> None:
def setup(app: Sphinx) -> ExtensionMetadata:
"""Add extension to sphinx."""
app.add_directive("problem", ProblemDirective)
app.add_domain(ProblemDomain)
return {
"version": "0.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}


class Lead:
"""Option to specify a lead in the problem directive."""

caption = "Lead"

def __init__(self, value: str) -> None:
self.handle = None
self.name, self.handle = value.split(" @", 1) if " @" in value else (value, None)

class ProblemDirective(SphinxDirective):
required_arguments = 1
def to_node(self) -> nodes.Node:
p = nodes.Text(self.name)
if self.handle:
node = nodes.paragraph()
node += [
p,
nodes.Text(" "),
nodes.reference(refuri=f"https://github.com/{self.handle}", text="@" + self.handle),
]
return node
return p


class ProblemTableDirective(SphinxDirective):
option_spec: ClassVar[dict[str, Any]] = {"lead": Lead, "problem_id": str}

def run(self) -> list[Any]:
with mock_imports(MODULE_WHITELIST, extra_members=MODULE_EXTRA_MEMBERS):
from engibench.core import ObjectiveDirection
from engibench.utils.all_problems import BUILTIN_PROBLEMS
problem_id = self.options.get("problem_id") or problem_id_from_docname(self.env.docname)
problem = import_problem(problem_id)
ObjectiveDirection = import_objective_direction() # noqa: N806

problem_id = self.arguments[0].strip()
problem = BUILTIN_PROBLEMS[problem_id]
docstring = unindent(problem.__doc__) if problem.__doc__ is not None else ""
docstring = inspect.cleandoc(docstring)
docstring = inspect.getdoc(problem)

image = nodes.image(uri=f"../_static/img/problems/{problem_id}.png", width="450px", align="center")

objectives = [
f"{obj}: ↑" if direction == ObjectiveDirection.MAXIMIZE else f"{obj}: ↓"
for obj, direction in problem.objectives
]
conditions = [f"{f.name}: {f.default}" for f in dataclasses.fields(problem.Conditions)]
conditions = read_dataclass(problem.Conditions)

lead = self.options.get("lead")

tab_data = [
("Version", str(problem.version)),
("Design space", make_code(repr(problem.design_space))),
("Objectives", make_multiline(objectives)),
("Conditions", make_multiline(conditions)),
("Conditions", make_simple_field_list(conditions)),
("Dataset", make_link(problem.dataset_id, f"https://huggingface.co/datasets/{problem.dataset_id}")),
("Container", make_code(problem.container_id) if problem.container_id is not None else None),
("Import", make_code(f"from {problem.__module__} import {problem.__name__}")),
*([("Lead", lead.to_node())] if lead is not None else []),
]

# Very ugly hack to retain the order of children
Expand All @@ -73,15 +108,60 @@ def run(self) -> list[Any]:
sec.clear()
sec.extend(header)

return [image, make_table(tab_data), *body]
return [image, *body, make_table(tab_data)]


def problem_id_from_docname(docname: str) -> str:
_, problem_id = docname.rsplit("/", 1)
return problem_id


class ConditionsDirective(SphinxDirective):
option_spec: ClassVar[dict[str, Any]] = {"problem_id": str, "defaults": bool}

def run(self) -> list[Any]:
problem_id = self.options.get("problem_id") or problem_id_from_docname(self.env.docname)
problem = import_problem(problem_id)

conditions = read_dataclass(problem.Conditions)
return [make_simple_field_list(conditions, defaults=self.options.get("defaults", False))]


class ProblemDomain(Domain):
name = "problem"
label = "Engibench Problem"

directives: ClassVar[dict[str, SphinxDirective]] = {
"table": ProblemTableDirective,
"conditions": ConditionsDirective,
}

def resolve_any_xref( # noqa: PLR0913
self,
env: "BuildEnvironment", # noqa: ARG002
fromdocname: str, # noqa: ARG002
builder: "Builder", # noqa: ARG002
target: str, # noqa: ARG002
node: addnodes.pending_xref, # noqa: ARG002
contnode: nodes.Element, # noqa: ARG002
) -> list[tuple[str, nodes.reference]]:
return []


def import_objective_direction() -> type[Any]:
"""Import the ObjectiveDirection enum without requiring engibench dependencies."""
with mock_imports(MODULE_WHITELIST, extra_members=MODULE_EXTRA_MEMBERS):
from engibench.core import ObjectiveDirection # noqa: PLC0415

return ObjectiveDirection

def make_section(title: str, section_id: str, body: list[Any]) -> nodes.section:
sec = nodes.section(ids=[section_id])
sec += nodes.title(text=title)
for element in body:
sec += element
return sec

def import_problem(problem_id: str) -> Any:
"""Import problem metadata without requiring engibench dependencies."""
with mock_imports(MODULE_WHITELIST, extra_members=MODULE_EXTRA_MEMBERS):
from engibench.utils.all_problems import BUILTIN_PROBLEMS # noqa: PLC0415

return BUILTIN_PROBLEMS[problem_id]


def make_link(text: str, uri: str) -> nodes.paragraph:
Expand Down Expand Up @@ -118,6 +198,59 @@ def make_table(tab_data: list[tuple[str, Any]]) -> nodes.table:
return table


@dataclasses.dataclass
class Field:
name: str
type: type | None
default: Any
doc: str | None


def make_field_list(fields: list[Field]) -> nodes.Node:
node = addnodes.desc()
for f in fields:
f_node = addnodes.desc()
node.append(f_node)
signode = addnodes.desc_signature("", "")
f_node.append(signode)
signode += addnodes.desc_name(f.name, f.name)
if f.type is not None:
signode += addnodes.desc_annotation(
directives.unchanged,
"",
addnodes.desc_sig_punctuation("", ": "),
addnodes.desc_sig_space(),
nodes.Text(f.type.__name__ if isinstance(f.type, type) else str(f.type)),
)
if f.default is not dataclasses.MISSING:
signode += addnodes.desc_annotation(
directives.unchanged,
"",
addnodes.desc_sig_punctuation("", " ="),
addnodes.desc_sig_space(),
nodes.Text(f.default),
)
if f.doc is not None:
f_node.append(addnodes.desc_content("", nodes.Text(f.doc)))

return node


def make_simple_field_list(fields: list[Field], *, defaults: bool = False) -> nodes.Node:
node = nodes.bullet_list()
for f in fields:
item = nodes.list_item()
node += item
p = nodes.paragraph()
p += nodes.literal(text=f.name)
text_pieces = [f.doc, f"(default: {f.default})" if f.default is not dataclasses.MISSING and defaults else None]
if f.doc is not None:
p += nodes.Text(": " + " ".join([piece for piece in text_pieces if piece]))
item += p

return node


def unindent(docstring: str) -> str:
if not docstring:
return ""
Expand All @@ -139,12 +272,65 @@ def unindent(docstring: str) -> str:


def line_indent(line: str) -> int | None:
"""Determine the indent of a lines"""
stripped = line.lstrip()
if stripped:
return len(line) - len(stripped)
return None


def read_dataclass(c: type) -> list[Field]:
"""Read the fields of a dataclass including docstrings for attributes."""
docs = read_field_docstrings(c)
types = get_type_hints(c)
fields = dataclasses.fields(c)
return [Field(name=f.name, default=f.default, doc=docs.get(f.name), type=types.get(f.name)) for f in fields]


def read_field_docstrings(c: type) -> dict[str, str]: # noqa C903
"""Read field docstrings from a dataclass."""
src = inspect.getsource(c)
indent = ((line_indent(src) or 0) + 4) * " "

def find_line_start(src: str) -> str | None:
pos = src.find("\n" + indent)
return None if pos == -1 else src[pos + len(indent) + 1 :]

def field_name(line: str) -> tuple[str, str | None]:
try:
name, rest = line.split(": ", 1)
except ValueError:
return line, None
return (rest, name) if name.isidentifier() else (line, None)

def docstr(line: str) -> tuple[str, str | None]:
if not line.startswith('"""'):
return line, None
pos = line.find('"""', 3)
if pos == -1:
raise ValueError("Unterminated docstring found")
return line[pos + 3 :], line[3:pos]

def tokenize(src: str) -> Iterable[tuple[str, str]]:
rest: str | None = src
f_name: str | None = None
while rest:
rest = find_line_start(rest)
if rest is None:
break
rest, new_f_name = field_name(rest)
if new_f_name is not None:
f_name = new_f_name
continue
if f_name is not None:
rest, docstring = docstr(rest)
if docstring is not None:
yield f_name, docstring
f_name = None

return dict(tokenize(src))


@contextlib.contextmanager
def mock_imports(whitelist: frozenset[str], extra_members: dict[str, list[str]] | None = None) -> Iterator[None]:
"""Add an import hook just after the builtin modules hook and the frozen module hook:
Expand Down
12 changes: 7 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))

# -- Project information -----------------------------------------------------
import os
from pathlib import Path
import sys
Expand Down Expand Up @@ -52,7 +48,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns: list[str] = []
exclude_patterns: list[str] = ["README.md"]

# Napoleon settings
napoleon_use_ivar = True
Expand Down Expand Up @@ -131,3 +127,9 @@
os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "tests", "tools")),
os.path.realpath(os.path.join(os.path.dirname(__file__), "utils")),
]

myst_url_schemes = {
"http": None,
"https": None,
"source": "https://github.com/IDEALLab/EngiBench/tree/main/{{path}}",
}
5 changes: 1 addition & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,8 @@ EngiOpt <https://github.com/IDEALLab/EngiOpt>
```

```{toctree}
:hidden:
:caption: Utils

utils/container
utils/slurm
utils/index
```

```{toctree}
Expand Down
Loading
Loading