Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 70 additions & 51 deletions startle/args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

Expand Down Expand Up @@ -70,6 +71,15 @@ def _args(self) -> list[Arg]:
seen.add(id(arg))
return unique_args

@property
def _children(self) -> Iterable["Arg"]:
"""
Yield all child Args instances. Only relevant when parsing recursively.
"""
for arg in self._args:
if arg.args:
yield arg

@staticmethod
def _is_name(value: str) -> str | Literal[False]:
"""
Expand Down Expand Up @@ -101,6 +111,20 @@ def _is_combined_short_names(value: str) -> str | Literal[False]:
return value
return False

def _find_arg_by_name(self, name: str) -> Arg | None:
"""
Find an argument by its name (short or long) among self or the children.
Returns the Arg if found, otherwise None.
"""
if name in self._name2idx:
return self._named_args[self._name2idx[name]]
for child in self._children:
assert child.args is not None, "Programming error!"
result = child.args._find_arg_by_name(name)
if result is not None:
return result
return None

def add(self, arg: Arg):
"""
Add an argument to the parser.
Expand Down Expand Up @@ -185,9 +209,9 @@ def _parse_combined_short_names(
if name == "?":
self.print_help()
raise SystemExit(0)
if name not in self._name2idx:
opt = self._find_arg_by_name(name)
if opt is None:
raise ParserOptionError(f"Unexpected option `{name}`!")
opt = self._named_args[self._name2idx[name]]
if opt.is_parsed and not opt.is_nary:
raise ParserOptionError(f"Option `{opt.name}` is multiply given!")

Expand Down Expand Up @@ -244,6 +268,14 @@ def _parse_named(
if name in ["help", "?"]:
self.print_help()
raise SystemExit(0)

for child in self._children:
try:
assert child.args is not None, "Programming error!"
return child.args._parse_named(name, args, state)
except ParserOptionError:
pass

if "=" in name:
return self._parse_equals_syntax(name, state)
normal_name = name.replace("_", "-")
Expand Down Expand Up @@ -333,44 +365,46 @@ def _parse_positional(self, args: list[str], state: _ParsingState) -> _ParsingSt
state.positional_idx += 1
return state

def _maybe_parse_children(self, args: list[str]) -> list[str]:
"""
Parse child Args, if any.
This method is only relevant when recurse=True is used in start() or parse().

Returns:
Remaining args after parsing child Args.
"""
remaining_args = args.copy()
for arg in self._args:
if child_args := arg.args:
try:
child_args.parse(remaining_args)
except ParserOptionError as e:
estr = str(e)
if estr.startswith("Required option") and estr.endswith(
" is not provided!"
):
# this is allowed if arg has a default value
if not arg.required:
arg._value = arg.default # type: ignore
arg._parsed = True # type: ignore
continue
# note that we do not consume any args, even partially
raise e

assert child_args._var_args is not None, "Programming error!"
remaining_args: list[str] = child_args._var_args.value or []
def _check_completion(self) -> None:
for child in self._children:
assert child.args is not None, "Programming error!"
try:
child.args._check_completion()

# construct the actual object
init_args, init_kwargs = child_args.make_func_args()
arg._value = arg.type_(*init_args, **init_kwargs) # type: ignore
arg._parsed = True # type: ignore
init_args, init_kwargs = child.args.make_func_args()
child._value = child.type_(*init_args, **init_kwargs) # type: ignore
child._parsed = True # type: ignore
except ParserOptionError as e:
estr = str(e)
if estr.startswith("Required option") and estr.endswith(
" is not provided!"
):
# this is allowed if arg has a default value
if not child.required:
child._value = child.default # type: ignore
child._parsed = True # type: ignore
continue
raise e

return remaining_args
# check if all required arguments are given, assign defaults otherwise
for arg in self._positional_args + self._named_args:
if not arg.is_parsed:
if arg.required:
if arg.is_named:
# if a positional arg is also named, prefer this type of error message
raise ParserOptionError(
f"Required option `{arg.name}` is not provided!"
)
else:
raise ParserOptionError(
f"Required positional argument <{arg.name.long}> is not provided!"
)
else:
arg._value = arg.default # type: ignore
arg._parsed = True # type: ignore

def _parse(self, args: list[str]):
args = self._maybe_parse_children(args)
state = _ParsingState()

while state.idx < len(args):
Expand All @@ -397,22 +431,7 @@ def _parse(self, args: list[str]):
# this must be a positional argument
state = self._parse_positional(args, state)

# check if all required arguments are given, assign defaults otherwise
for arg in self._positional_args + self._named_args:
if not arg.is_parsed:
if arg.required:
if arg.is_named:
# if a positional arg is also named, prefer this type of error message
raise ParserOptionError(
f"Required option `{arg.name}` is not provided!"
)
else:
raise ParserOptionError(
f"Required positional argument <{arg.name.long}> is not provided!"
)
else:
arg._value = arg.default # type: ignore
arg._parsed = True # type: ignore
self._check_completion()

def make_func_args(self) -> tuple[list[Any], dict[str, Any]]:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_help/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Any, Callable

from rich.console import Console
from startle._inspect.make_args import make_args_from_class, make_args_from_func
Expand All @@ -15,7 +15,7 @@ def remove_trailing_spaces(text: str) -> str:


def check_help_from_func(
f: Callable, program_name: str, expected: str, recurse: bool = False
f: Callable[..., Any], program_name: str, expected: str, recurse: bool = False
):
console = Console(width=120, highlight=False, force_terminal=True)
with console.capture() as capture:
Expand Down
102 changes: 86 additions & 16 deletions tests/test_recursive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from dataclasses import dataclass, field
from typing import Callable, Literal, TypedDict
from typing import Any, Callable, Literal, TypedDict

from pytest import mark, raises
from startle.error import ParserConfigError, ParserOptionError
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_recursive_w_defaults(
kind: Literal["single", "pair"] | None,
count: int | None,
) -> None:
cli_args = []
cli_args: list[str] = []
config_kwargs = {}
if sides is not None:
cli_args += [sides_opt, str(sides)]
Expand All @@ -76,7 +76,7 @@ def test_recursive_w_defaults(
if count is not None:
cli_args += [count_opt, str(count)]

expected_cfg = DieConfig(**config_kwargs)
expected_cfg = DieConfig(**config_kwargs) # type: ignore[arg-type]
expected_count = count if count is not None else 1
check_args(throw_dice, cli_args, [expected_cfg, expected_count], {}, recurse=True)
with raises(
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_recursive_w_required(
func = throw_dice2
else:
func = throw_dice2_td
cli_args = []
cli_args: list[str] = []
config_kwargs = {}
if sides is not None:
cli_args += [sides_opt, str(sides)]
Expand All @@ -182,10 +182,10 @@ def test_recursive_w_required(
):
check_args(func, cli_args, [], {}, recurse=True)
else:
expected_cfg = (
DieConfig2(**config_kwargs) if cls is DieConfig2 else {**config_kwargs}
expected_cfg: DieConfig2 | dict[str, Any] = (
DieConfig2(**config_kwargs) if cls is DieConfig2 else {**config_kwargs} # type: ignore[arg-type]
)
expected_count = count if count is not None else 1
expected_count = count
check_args(func, cli_args, [expected_cfg, expected_count], {}, recurse=True)


Expand Down Expand Up @@ -250,14 +250,15 @@ def test_recursive_w_inner_required() -> None:
# we will fail to parse and use the default for cfg.
# But because --sides is then not consumed by the child parser,
# it will surface up to the parent as an unexpected option.
with raises(ParserOptionError, match="Unexpected option `sides`!"):
check_args(
throw_dice3,
["--count", "2", "--sides", "4"],
[2, DieConfig2(sides=6, kind="single")],
{},
recurse=True,
)
# NOTE: "sides" is partially consumed now.
# with raises(ParserOptionError, match="Unexpected option `sides`!"):
# check_args(
# throw_dice3,
# ["--count", "2", "--sides", "4"],
# [2, DieConfig2(sides=6, kind="single")],
# {},
# recurse=True,
# )


class ConfigWithVarArgs:
Expand Down Expand Up @@ -507,7 +508,7 @@ def fuse2td(cfg: FusionConfig2TD) -> None:


@mark.parametrize("fuse", [fuse1, fuse2, fuse2td])
def test_recursive_dataclass_help(fuse: Callable) -> None:
def test_recursive_dataclass_help(fuse: Callable[..., Any]) -> None:
if fuse is fuse1:
expected = FusionConfig(
left_path="monster1.dat",
Expand Down Expand Up @@ -669,3 +670,72 @@ def test_recursive_dataclass_non_class() -> None:
match="Cannot recurse into parameter `io_paths` of non-class type `IOPaths2 | tuple[str, str]` in `fuse4()`!",
):
check_help_from_func(fuse4, "fuse.py", "", recurse=True)


@dataclass(kw_only=True)
class AppleConfig:
"""
Configuration for apple.

Attributes:
color: The color of the apple.
heavy: Whether the apple is heavy.
"""

color: str = "red"
heavy: bool = False


@dataclass(kw_only=True)
class BananaConfig:
"""
Configuration for banana.

Attributes:
length: The length of the banana.
ripe: Whether the banana is ripe.
"""

length: float = 6.0
ripe: bool = False


def make_fruit_salad(
apple_cfg: AppleConfig,
banana_cfg: BananaConfig,
servings: int = 1,
) -> None:
"""
Make a fruit salad.

Args:
apple_cfg: Configuration for the apple.
banana_cfg: Configuration for the banana.
servings: Number of servings.
"""
pass


@mark.parametrize(
"cli_args",
[
["--color", "green", "--heavy", "--length", "7.5", "--ripe", "--servings", "3"],
["--color", "green", "--length", "7.5", "-h", "-r", "-s", "3"],
["--color", "green", "--length", "7.5", "-hrs", "3"],
["--color", "green", "--length", "7.5", "-hrs=3"],
["--color", "green", "--length", "7.5", "-rhs", "3"],
["--color", "green", "--length", "7.5", "-rhs=3"],
],
)
def test_combined_short_flags(cli_args: list[str]) -> None:
check_args(
make_fruit_salad,
cli_args,
[
AppleConfig(color="green", heavy=True),
BananaConfig(length=7.5, ripe=True),
3,
],
{},
recurse=True,
)