From 084cc22631c95305e4b510ced4aadf5e499641fa Mon Sep 17 00:00:00 2001 From: oir Date: Thu, 1 Jan 2026 21:53:52 -0500 Subject: [PATCH 1/5] update children parsing logic --- startle/args.py | 103 ++++++++++++++++++++++++-------- tests/test_help/_utils.py | 4 +- tests/test_recursive.py | 120 +++++++++++++++++++++++++++++++++----- 3 files changed, 184 insertions(+), 43 deletions(-) diff --git a/startle/args.py b/startle/args.py index 0a6f8ea..f74b92e 100644 --- a/startle/args.py +++ b/startle/args.py @@ -1,4 +1,5 @@ import sys +from collections.abc import Iterable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal @@ -70,6 +71,15 @@ def _args(self) -> list[Arg]: seen.add(id(arg)) return unique_args + @property + def _children(self) -> Iterable["Arg"]: + """ + Yield all child Args instances. Only relevant when parsing recursively. + """ + for arg in self._args: + if arg.args: + yield arg + @staticmethod def _is_name(value: str) -> str | Literal[False]: """ @@ -244,6 +254,14 @@ def _parse_named( if name in ["help", "?"]: self.print_help() raise SystemExit(0) + + for child in self._children: + try: + assert child.args is not None, "Programming error!" + return child.args._parse_named(name, args, state) + except ParserOptionError: + pass + if "=" in name: return self._parse_equals_syntax(name, state) normal_name = name.replace("_", "-") @@ -369,8 +387,47 @@ def _maybe_parse_children(self, args: list[str]) -> list[str]: return remaining_args + def _check_completion(self) -> None: + for child in self._children: + assert child.args is not None, "Programming error!" + try: + child.args._check_completion() + + # construct the actual object + init_args, init_kwargs = child.args.make_func_args() + child._value = child.type_(*init_args, **init_kwargs) # type: ignore + child._parsed = True # type: ignore + except ParserOptionError as e: + estr = str(e) + if estr.startswith("Required option") and estr.endswith( + " is not provided!" + ): + # this is allowed if arg has a default value + if not child.required: + child._value = child.default # type: ignore + child._parsed = True # type: ignore + continue + raise e + + # check if all required arguments are given, assign defaults otherwise + for arg in self._positional_args + self._named_args: + if not arg.is_parsed: + if arg.required: + if arg.is_named: + # if a positional arg is also named, prefer this type of error message + raise ParserOptionError( + f"Required option `{arg.name}` is not provided!" + ) + else: + raise ParserOptionError( + f"Required positional argument <{arg.name.long}> is not provided!" + ) + else: + arg._value = arg.default # type: ignore + arg._parsed = True # type: ignore + def _parse(self, args: list[str]): - args = self._maybe_parse_children(args) + # args = self._maybe_parse_children(args) state = _ParsingState() while state.idx < len(args): @@ -385,34 +442,30 @@ def _parse(self, args: list[str]): if names := self._is_combined_short_names(args[state.idx]): state = self._parse_combined_short_names(names, args, state) else: - try: - state = self._parse_named(name, args, state) - except ParserOptionError as e: - if self._var_args and str(e).startswith("Unexpected option"): - self._var_args.parse(args[state.idx]) - state.idx += 1 - else: - raise + parsed_by_child = False + # for child in self._children: + # assert child.args is not None, "Programming error!" + # if name in child.args._name2idx: + # # delegate to child Args + # state = child.args._parse_named(name, args, state) + # parsed_by_child = True + # break + if not parsed_by_child: + try: + state = self._parse_named(name, args, state) + except ParserOptionError as e: + if self._var_args and str(e).startswith( + "Unexpected option" + ): + self._var_args.parse(args[state.idx]) + state.idx += 1 + else: + raise else: # this must be a positional argument state = self._parse_positional(args, state) - # check if all required arguments are given, assign defaults otherwise - for arg in self._positional_args + self._named_args: - if not arg.is_parsed: - if arg.required: - if arg.is_named: - # if a positional arg is also named, prefer this type of error message - raise ParserOptionError( - f"Required option `{arg.name}` is not provided!" - ) - else: - raise ParserOptionError( - f"Required positional argument <{arg.name.long}> is not provided!" - ) - else: - arg._value = arg.default # type: ignore - arg._parsed = True # type: ignore + self._check_completion() def make_func_args(self) -> tuple[list[Any], dict[str, Any]]: """ diff --git a/tests/test_help/_utils.py b/tests/test_help/_utils.py index 598134e..96ad3c4 100644 --- a/tests/test_help/_utils.py +++ b/tests/test_help/_utils.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Any, Callable from rich.console import Console from startle._inspect.make_args import make_args_from_class, make_args_from_func @@ -15,7 +15,7 @@ def remove_trailing_spaces(text: str) -> str: def check_help_from_func( - f: Callable, program_name: str, expected: str, recurse: bool = False + f: Callable[..., Any], program_name: str, expected: str, recurse: bool = False ): console = Console(width=120, highlight=False, force_terminal=True) with console.capture() as capture: diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 659eddb..26d2a71 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -1,6 +1,6 @@ import re from dataclasses import dataclass, field -from typing import Callable, Literal, TypedDict +from typing import Any, Callable, Literal, TypedDict from pytest import mark, raises from startle.error import ParserConfigError, ParserOptionError @@ -65,7 +65,7 @@ def test_recursive_w_defaults( kind: Literal["single", "pair"] | None, count: int | None, ) -> None: - cli_args = [] + cli_args: list[str] = [] config_kwargs = {} if sides is not None: cli_args += [sides_opt, str(sides)] @@ -76,7 +76,7 @@ def test_recursive_w_defaults( if count is not None: cli_args += [count_opt, str(count)] - expected_cfg = DieConfig(**config_kwargs) + expected_cfg = DieConfig(**config_kwargs) # type: ignore[arg-type] expected_count = count if count is not None else 1 check_args(throw_dice, cli_args, [expected_cfg, expected_count], {}, recurse=True) with raises( @@ -157,7 +157,7 @@ def test_recursive_w_required( func = throw_dice2 else: func = throw_dice2_td - cli_args = [] + cli_args: list[str] = [] config_kwargs = {} if sides is not None: cli_args += [sides_opt, str(sides)] @@ -182,10 +182,10 @@ def test_recursive_w_required( ): check_args(func, cli_args, [], {}, recurse=True) else: - expected_cfg = ( - DieConfig2(**config_kwargs) if cls is DieConfig2 else {**config_kwargs} + expected_cfg: DieConfig2 | dict[str, Any] = ( + DieConfig2(**config_kwargs) if cls is DieConfig2 else {**config_kwargs} # type: ignore[arg-type] ) - expected_count = count if count is not None else 1 + expected_count = count check_args(func, cli_args, [expected_cfg, expected_count], {}, recurse=True) @@ -250,14 +250,15 @@ def test_recursive_w_inner_required() -> None: # we will fail to parse and use the default for cfg. # But because --sides is then not consumed by the child parser, # it will surface up to the parent as an unexpected option. - with raises(ParserOptionError, match="Unexpected option `sides`!"): - check_args( - throw_dice3, - ["--count", "2", "--sides", "4"], - [2, DieConfig2(sides=6, kind="single")], - {}, - recurse=True, - ) + # NOTE: "sides" is partially consumed now. + # with raises(ParserOptionError, match="Unexpected option `sides`!"): + # check_args( + # throw_dice3, + # ["--count", "2", "--sides", "4"], + # [2, DieConfig2(sides=6, kind="single")], + # {}, + # recurse=True, + # ) class ConfigWithVarArgs: @@ -507,7 +508,7 @@ def fuse2td(cfg: FusionConfig2TD) -> None: @mark.parametrize("fuse", [fuse1, fuse2, fuse2td]) -def test_recursive_dataclass_help(fuse: Callable) -> None: +def test_recursive_dataclass_help(fuse: Callable[..., Any]) -> None: if fuse is fuse1: expected = FusionConfig( left_path="monster1.dat", @@ -669,3 +670,90 @@ def test_recursive_dataclass_non_class() -> None: match="Cannot recurse into parameter `io_paths` of non-class type `IOPaths2 | tuple[str, str]` in `fuse4()`!", ): check_help_from_func(fuse4, "fuse.py", "", recurse=True) + + +@dataclass(kw_only=True) +class AppleConfig: + """ + Configuration for apple. + + Attributes: + color: The color of the apple. + heavy: Whether the apple is heavy. + """ + + color: str = "red" + heavy: bool = False + + +@dataclass(kw_only=True) +class BananaConfig: + """ + Configuration for banana. + + Attributes: + length: The length of the banana. + ripe: Whether the banana is ripe. + """ + + length: float = 6.0 + ripe: bool = False + + +def make_fruit_salad( + apple_cfg: AppleConfig, + banana_cfg: BananaConfig, + servings: int = 1, +) -> None: + """ + Make a fruit salad. + + Args: + apple_cfg: Configuration for the apple. + banana_cfg: Configuration for the banana. + servings: Number of servings. + """ + pass + + +def test_combined_short_flags() -> None: + check_args( + make_fruit_salad, + [ + "--color", + "green", + "--heavy", + "--length", + "7.5", + "--ripe", + "--servings", + "3", + ], + [ + AppleConfig(color="green", heavy=True), + BananaConfig(length=7.5, ripe=True), + 3, + ], + {}, + recurse=True, + ) + check_args( + make_fruit_salad, + [ + "--color", + "green", + "--length", + "7.5", + "-h", + "-r", + "-s", + "3", + ], + [ + AppleConfig(color="green", heavy=True), + BananaConfig(length=7.5, ripe=True), + 3, + ], + {}, + recurse=True, + ) From 5ad1bb5fe5bf4d2e84afa714bcf49107344be872 Mon Sep 17 00:00:00 2001 From: oir Date: Fri, 2 Jan 2026 20:14:43 -0500 Subject: [PATCH 2/5] remove unused method --- startle/args.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/startle/args.py b/startle/args.py index f74b92e..618af05 100644 --- a/startle/args.py +++ b/startle/args.py @@ -351,42 +351,6 @@ def _parse_positional(self, args: list[str], state: _ParsingState) -> _ParsingSt state.positional_idx += 1 return state - def _maybe_parse_children(self, args: list[str]) -> list[str]: - """ - Parse child Args, if any. - This method is only relevant when recurse=True is used in start() or parse(). - - Returns: - Remaining args after parsing child Args. - """ - remaining_args = args.copy() - for arg in self._args: - if child_args := arg.args: - try: - child_args.parse(remaining_args) - except ParserOptionError as e: - estr = str(e) - if estr.startswith("Required option") and estr.endswith( - " is not provided!" - ): - # this is allowed if arg has a default value - if not arg.required: - arg._value = arg.default # type: ignore - arg._parsed = True # type: ignore - continue - # note that we do not consume any args, even partially - raise e - - assert child_args._var_args is not None, "Programming error!" - remaining_args: list[str] = child_args._var_args.value or [] - - # construct the actual object - init_args, init_kwargs = child_args.make_func_args() - arg._value = arg.type_(*init_args, **init_kwargs) # type: ignore - arg._parsed = True # type: ignore - - return remaining_args - def _check_completion(self) -> None: for child in self._children: assert child.args is not None, "Programming error!" @@ -427,7 +391,6 @@ def _check_completion(self) -> None: arg._parsed = True # type: ignore def _parse(self, args: list[str]): - # args = self._maybe_parse_children(args) state = _ParsingState() while state.idx < len(args): From 9cede3feb6b0b4cd50fd5e1b379c4464a3e31de4 Mon Sep 17 00:00:00 2001 From: oir Date: Fri, 2 Jan 2026 21:16:18 -0500 Subject: [PATCH 3/5] remove dead code --- startle/args.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/startle/args.py b/startle/args.py index 618af05..78a5ef9 100644 --- a/startle/args.py +++ b/startle/args.py @@ -405,25 +405,14 @@ def _parse(self, args: list[str]): if names := self._is_combined_short_names(args[state.idx]): state = self._parse_combined_short_names(names, args, state) else: - parsed_by_child = False - # for child in self._children: - # assert child.args is not None, "Programming error!" - # if name in child.args._name2idx: - # # delegate to child Args - # state = child.args._parse_named(name, args, state) - # parsed_by_child = True - # break - if not parsed_by_child: - try: - state = self._parse_named(name, args, state) - except ParserOptionError as e: - if self._var_args and str(e).startswith( - "Unexpected option" - ): - self._var_args.parse(args[state.idx]) - state.idx += 1 - else: - raise + try: + state = self._parse_named(name, args, state) + except ParserOptionError as e: + if self._var_args and str(e).startswith("Unexpected option"): + self._var_args.parse(args[state.idx]) + state.idx += 1 + else: + raise else: # this must be a positional argument state = self._parse_positional(args, state) From 3ac3a780d0846f582f325cfb51fca5e8dfdc243a Mon Sep 17 00:00:00 2001 From: oir Date: Fri, 2 Jan 2026 22:43:40 -0500 Subject: [PATCH 4/5] parse combined flags correctly in the recursive case --- startle/args.py | 18 ++++++++++++++++-- tests/test_recursive.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/startle/args.py b/startle/args.py index 78a5ef9..16d4163 100644 --- a/startle/args.py +++ b/startle/args.py @@ -111,6 +111,20 @@ def _is_combined_short_names(value: str) -> str | Literal[False]: return value return False + def _find_arg_by_name(self, name: str) -> Arg | None: + """ + Find an argument by its name (short or long) among self or the children. + Returns the Arg if found, otherwise None. + """ + if name in self._name2idx: + return self._named_args[self._name2idx[name]] + for child in self._children: + assert child.args is not None, "Programming error!" + result = child.args._find_arg_by_name(name) + if result is not None: + return result + return None + def add(self, arg: Arg): """ Add an argument to the parser. @@ -195,9 +209,9 @@ def _parse_combined_short_names( if name == "?": self.print_help() raise SystemExit(0) - if name not in self._name2idx: + opt = self._find_arg_by_name(name) + if opt is None: raise ParserOptionError(f"Unexpected option `{name}`!") - opt = self._named_args[self._name2idx[name]] if opt.is_parsed and not opt.is_nary: raise ParserOptionError(f"Option `{opt.name}` is multiply given!") diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 26d2a71..4a628ed 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -757,3 +757,21 @@ def test_combined_short_flags() -> None: {}, recurse=True, ) + check_args( + make_fruit_salad, + [ + "--color", + "green", + "--length", + "7.5", + "-hrs", + "3", + ], + [ + AppleConfig(color="green", heavy=True), + BananaConfig(length=7.5, ripe=True), + 3, + ], + {}, + recurse=True, + ) From 7f385e3e1f6d9d58c1a43174272ab20b40ee74b9 Mon Sep 17 00:00:00 2001 From: oir Date: Sat, 3 Jan 2026 11:01:07 -0500 Subject: [PATCH 5/5] expand test --- tests/test_recursive.py | 62 +++++++++-------------------------------- 1 file changed, 13 insertions(+), 49 deletions(-) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 4a628ed..257a50d 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -716,57 +716,21 @@ def make_fruit_salad( pass -def test_combined_short_flags() -> None: - check_args( - make_fruit_salad, - [ - "--color", - "green", - "--heavy", - "--length", - "7.5", - "--ripe", - "--servings", - "3", - ], - [ - AppleConfig(color="green", heavy=True), - BananaConfig(length=7.5, ripe=True), - 3, - ], - {}, - recurse=True, - ) - check_args( - make_fruit_salad, - [ - "--color", - "green", - "--length", - "7.5", - "-h", - "-r", - "-s", - "3", - ], - [ - AppleConfig(color="green", heavy=True), - BananaConfig(length=7.5, ripe=True), - 3, - ], - {}, - recurse=True, - ) +@mark.parametrize( + "cli_args", + [ + ["--color", "green", "--heavy", "--length", "7.5", "--ripe", "--servings", "3"], + ["--color", "green", "--length", "7.5", "-h", "-r", "-s", "3"], + ["--color", "green", "--length", "7.5", "-hrs", "3"], + ["--color", "green", "--length", "7.5", "-hrs=3"], + ["--color", "green", "--length", "7.5", "-rhs", "3"], + ["--color", "green", "--length", "7.5", "-rhs=3"], + ], +) +def test_combined_short_flags(cli_args: list[str]) -> None: check_args( make_fruit_salad, - [ - "--color", - "green", - "--length", - "7.5", - "-hrs", - "3", - ], + cli_args, [ AppleConfig(color="green", heavy=True), BananaConfig(length=7.5, ripe=True),