
Fix HfArgumentParser to filter out dict types from Union#39741

Open
st81 wants to merge 1 commit into huggingface:main from st81:fix-hfargumentparser-filter-dict-

Conversation

@st81
Contributor

@st81 st81 commented Jul 28, 2025

What does this PR do?

This PR updates HfArgumentParser to filter out dict types from Union annotations. This prevents runtime errors with argparse, which does not support parsing arguments as dicts. Previous work (PR #39467) reordered Union[str, dict] to Union[dict, str] throughout the codebase, but this approach is fragile and could break if new Union[str, dict] annotations are introduced. This change ensures that dict types are always filtered out, making the parser more robust and maintainable.

Fixes #39462
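The filtering idea can be sketched in isolation. This is an illustrative sketch only, not the actual HfArgumentParser internals; the helper name strip_dict_from_union is hypothetical:

```python
# Illustrative sketch of the filtering idea, not the actual HfArgumentParser
# code: drop dict members from a Union annotation so that argparse never
# receives dict as its type callable.
from typing import Optional, Union, get_args, get_origin


def strip_dict_from_union(annotation):
    """Return the annotation with any dict members removed from a Union."""
    if get_origin(annotation) is Union:
        remaining = tuple(
            a for a in get_args(annotation)
            if a is not dict and get_origin(a) is not dict  # drops dict and dict[...]
        )
        return remaining[0] if len(remaining) == 1 else Union[remaining]
    return annotation


print(strip_dict_from_union(Union[str, dict]))  # <class 'str'>
print(strip_dict_from_union(Optional[Union[str, dict]]))
```

Under this sketch, Union[str, dict] collapses to plain str and Optional[Union[str, dict]] becomes Optional[str], so argparse only ever sees types it can parse; the dict value is recovered later from the JSON string in post-processing.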

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

(same as previous PR)

Contributor

@Tavish9 Tavish9 left a comment


This prevents runtime errors with argparse, which does not support parsing arguments as dicts.

Please show me the smallest example that could fail on the main branch.

could break if new Union[str, dict] annotations are introduced

this would not happen since it cannot pass the tests I added.

This change ensures that dict types are always filtered out

This is not reasonable: people cannot pass a dict to the parser anymore. Also, this change does not pass the tests :(

@st81
Contributor Author

st81 commented Jul 29, 2025

@Tavish9 Thank you for your feedback!

Please show me the smallest example that could fail on the main branch.

For example, if we add a new argument new_arg: Optional[Union[str, dict]] = None to the function _get_torch_dtype in modeling_utils.py,

def _get_torch_dtype(
    cls,
    torch_dtype: Optional[Union[str, torch.dtype, dict]],
    checkpoint_files: Optional[list[str]],
    config: PretrainedConfig,
    sharded_metadata: Optional[dict],
    state_dict: Optional[dict],
    weights_only: bool,
    new_arg: Optional[Union[str, dict]] = None, # <-- this is the new argument
) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:

the following code snippet will fail on the main branch.

from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)


def main():
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]
    print("parse success")
    print(type(training_args.accelerator_config.gradient_accumulation_kwargs))


if __name__ == "__main__":
    main()
python3 test.py --accelerator_config '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

this would not happen since it cannot pass the tests I added.

You're absolutely right! The existing tests in tests/utils/test_hf_argparser.py would catch this with the following error message:

The error message is long so I enclose it in a details tag.
python3 -m pytest tests/utils/test_hf_argparser.py

__________________________________________________________________________ HfArgumentParserTest.test_16_cli_input_parsing __________________________________________________________________________

self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
action = _StoreAction(option_strings=['--accelerator_config', '--accelerator-config'], dest='accelerator_config', nargs=None, c...ccelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`.', metavar=None)
arg_string = '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

    def _get_value(self, action, arg_string):
        type_func = self._registry_get('type', action.type, action.type)
        if not callable(type_func):
            msg = _('%r is not callable')
            raise ArgumentError(action, msg % type_func)
    
        # convert the value to the appropriate type
        try:
>           result = type_func(arg_string)
                     ^^^^^^^^^^^^^^^^^^^^^
E           ValueError: dictionary update sequence element #0 has length 1; 2 is required

../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2546: ValueError

During handling of the above exception, another exception occurred:

self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
args = ['--accelerator_config', '{"gradient_accumulation_kwargs": {"num_steps": 2}}']
namespace = Namespace(output_dir=None, overwrite_output_dir=False, do_train=False, do_eval=False, do_predict=False, eval_strategy=...se, use_liger_kernel=False, liger_kernel_config=None, eval_use_gather_object=False, average_tokens_across_devices=True)
intermixed = False

    def _parse_known_args2(self, args, namespace, intermixed):
        if args is None:
            # args default to the system args
            args = _sys.argv[1:]
        else:
            # make sure that args are mutable
            args = list(args)
    
        # default Namespace built from parser defaults
        if namespace is None:
            namespace = Namespace()
    
        # add any action defaults that aren't present
        for action in self._actions:
            if action.dest is not SUPPRESS:
                if not hasattr(namespace, action.dest):
                    if action.default is not SUPPRESS:
                        setattr(namespace, action.dest, action.default)
    
        # add any parser defaults that aren't present
        for dest in self._defaults:
            if not hasattr(namespace, dest):
                setattr(namespace, dest, self._defaults[dest])
    
        # parse the arguments and exit if there are any errors
        if self.exit_on_error:
            try:
>               namespace, args = self._parse_known_args(args, namespace, intermixed)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1943: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2184: in _parse_known_args
    start_index = consume_optional(start_index)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2113: in consume_optional
    take_action(action, args, option_string)
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2003: in take_action
    argument_values = self._get_values(action, argument_strings)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2513: in _get_values
    value = self._get_value(action, arg_string)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
action = _StoreAction(option_strings=['--accelerator_config', '--accelerator-config'], dest='accelerator_config', nargs=None, c...ccelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`.', metavar=None)
arg_string = '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

    def _get_value(self, action, arg_string):
        type_func = self._registry_get('type', action.type, action.type)
        if not callable(type_func):
            msg = _('%r is not callable')
            raise ArgumentError(action, msg % type_func)
    
        # convert the value to the appropriate type
        try:
            result = type_func(arg_string)
    
        # ArgumentTypeErrors indicate errors
        except ArgumentTypeError as err:
            msg = str(err)
            raise ArgumentError(action, msg)
    
        # TypeErrors or ValueErrors also indicate errors
        except (TypeError, ValueError):
            name = getattr(action.type, '__name__', repr(action.type))
            args = {'type': name, 'value': arg_string}
            msg = _('invalid %(type)s value: %(value)r')
>           raise ArgumentError(action, msg % args)
E           argparse.ArgumentError: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2558: ArgumentError

During handling of the above exception, another exception occurred:

self = <tests.utils.test_hf_argparser.HfArgumentParserTest testMethod=test_16_cli_input_parsing>

    @require_torch
    @patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
    def test_16_cli_input_parsing(self):
        parser = HfArgumentParser(TrainingArguments)
>       training_args = parser.parse_args_into_dataclasses()[0]
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

tests/utils/test_hf_argparser.py:491: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/transformers/hf_argparser.py:351: in parse_args_into_dataclasses
    namespace, remaining_args = self.parse_known_args(args=args)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1914: in parse_known_args
    return self._parse_known_args2(args, namespace, intermixed=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1945: in _parse_known_args2
    self.error(str(err))
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2650: in error
    self.exit(2, _('%(prog)s: error: %(message)s\n') % args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True), status = 2
message = 'test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: \'{"gradient_accumulation_kwargs": {"num_steps": 2}}\'\n'

    def exit(self, status=0, message=None):
        if message:
            self._print_message(message, _sys.stderr)
>       _sys.exit(status)
E       SystemExit: 2

../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2637: SystemExit
--------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
usage: test.py [-h] [--output_dir OUTPUT_DIR]
               ... (remaining TrainingArguments options elided) ...
test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

===================================================================================== short test summary info ======================================================================================
FAILED tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_16_cli_input_parsing - SystemExit: 2
============================================================================= 1 failed, 16 passed, 9 warnings in 7.52s ============================================================================

But I think this would be confusing for developers because:

  • The error message isn't very clear
  • The error occurs in HfArgumentParser, not in the _get_torch_dtype function they actually modified
  • It's non-obvious that the order of Union[str, dict] vs Union[dict, str] matters

This is not reasonable: people cannot pass a dict to the parser anymore. Also, this change does not pass the tests :(

Users can still pass dicts via JSON strings (which I think is the intended behavior) and they get parsed as dicts:

python3 test.py --accelerator_config '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
parse success
<class 'dict'>
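The conversion that makes this work can be sketched as follows. This is a simplified, hypothetical stand-in for the post-processing that turns string-valued dict fields into real dicts; the helper name coerce_dict_field is illustrative, not the actual TrainingArguments code:

```python
# Simplified, hypothetical sketch of the post-processing step that turns a
# JSON string received from the CLI into a dict; the real TrainingArguments
# logic differs in detail.
import json


def coerce_dict_field(value):
    """Decode a JSON-encoded dict passed as a CLI string; pass other values through."""
    if isinstance(value, str) and value.lstrip().startswith("{"):
        return json.loads(value)
    return value


cfg = coerce_dict_field('{"gradient_accumulation_kwargs": {"num_steps": 2}}')
print(type(cfg))  # <class 'dict'>
```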

Also, the tests pass with this change.

python3 -m pytest tests/utils/test_hf_argparser.py 
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
configfile: pyproject.toml
plugins: xdist-3.8.0
collected 17 items                                                                                                                                                                                 

tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_00_basic PASSED                                                                                                                 [  5%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_01_with_default PASSED                                                                                                          [ 11%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_02_with_default_bool PASSED                                                                                                     [ 17%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_03_with_enum PASSED                                                                                                             [ 23%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_04_with_literal PASSED                                                                                                          [ 29%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_05_with_list PASSED                                                                                                             [ 35%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_06_with_optional PASSED                                                                                                         [ 41%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_07_with_required PASSED                                                                                                         [ 47%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_08_with_string_literal_annotation PASSED                                                                                        [ 52%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_09_parse_dict PASSED                                                                                                            [ 58%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_10_parse_dict_extra_key PASSED                                                                                                  [ 64%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_11_parse_json PASSED                                                                                                            [ 70%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_12_parse_yaml PASSED                                                                                                            [ 76%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_13_valid_dict_annotation PASSED                                                                                                 [ 82%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_14_valid_dict_input_parsing PASSED                                                                                              [ 88%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_15_integration_training_args PASSED                                                                                             [ 94%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_16_cli_input_parsing PASSED                                                                                                     [100%]

================================================================================== 17 passed, 1 warning in 0.22s ===================================================================================

@SunMarc
Member

SunMarc commented Jul 29, 2025

As long as dicts passed as strings are correctly parsed, I'm happy to have this! WDYT @Tavish9 ?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Tavish9
Contributor

Tavish9 commented Jul 29, 2025

Thanks for detailed response!

This is not reasonable: people cannot pass a dict to the parser anymore. Also, this change does not pass the tests :(

Users can still pass dicts via JSON strings (which I think is the intended behavior) and they get parsed as dicts:

Apologies for the misunderstanding. Yes, it should work.

But I think we can handle this more gracefully.

Let me figure it out.

@Tavish9
Contributor

Tavish9 commented Jul 29, 2025

After diving into TrainingArguments and HfArgumentParser, I found that only one case still needs to be taken care of:
dict | str or str | dict.

Neither the main branch nor this PR solves this problem directly; it requires redesigning HfArgumentParser._parse_dataclass_field and TrainingArguments.__post_init__.
Is this problem worth solving? @SunMarc

If yes, I would open a new PR fixing both the dict/str order and the dict/str-without-None cases; otherwise, I'm fine with this PR, as it fixes the dict/str order.

@SunMarc
Member

SunMarc commented Jul 30, 2025

After diving into TrainingArguments and HfArgumentParser, I found that only one case still needs to be taken care of:
dict | str or str | dict.

Neither the main branch nor this PR solves this problem directly; it requires redesigning HfArgumentParser._parse_dataclass_field and TrainingArguments.__post_init__.
Is this problem worth solving? @SunMarc

If yes, I would open a new PR fixing both the dict/str order and the dict/str-without-None cases; otherwise, I'm fine with this PR, as it fixes the dict/str order.

Can you specify in more detail which case still fails with this PR? Happy to have the PR if you are willing to spend some time on that!

@Tavish9
Contributor

Tavish9 commented Jul 30, 2025

import unittest
from dataclasses import dataclass, field
from unittest.mock import patch

from transformers import (
    HfArgumentParser,
    TrainingArguments,
)


@dataclass
class MyTrainingArguments(TrainingArguments):
    dict1: dict | str = field(default="")
    dict2: str | dict = field(default_factory=lambda: {"test": True})

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["dict1", "dict2"]


class HfArgumentParserTest(unittest.TestCase):
    @patch("sys.argv", ["test.py", "--dict1", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
    def test_01(self):
        parser = HfArgumentParser(MyTrainingArguments)
        training_args = parser.parse_args_into_dataclasses()[0]
        assert training_args.dict1 == {"gradient_accumulation_kwargs": {"num_steps": 2}} and training_args.dict2 == {"test": True}

    @patch("sys.argv", ["test.py", "--dict2", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
    def test_02(self):
        parser = HfArgumentParser(MyTrainingArguments)
        training_args = parser.parse_args_into_dataclasses()[0]
        assert training_args.dict2 == {"gradient_accumulation_kwargs": {"num_steps": 2}} and training_args.dict1 == ""

pytest test.py

  1. main branch: --dict1: invalid dict value; --dict2: invalid dict value
  2. this PR: AttributeError: type object 'str' has no attribute '__args__'. Did you mean: '__add__'?

A new PR needs to add this unittest as well. @SunMarc , cc @st81

@st81
Contributor Author

st81 commented Jul 30, 2025

@Tavish9
Thank you for your detailed explanation! I understand that if dict | str or str | dict types exist in TrainingArguments, neither the main branch nor this PR branch will work as expected.

I'm fine with you taking time to handle this issue in a new PR!

I also understand the problem in HfArgumentParser._parse_dataclass_field: it seems the current main branch assumes that dict | str or str | dict types don't exist in TrainingArguments, because it filters out str types when NoneType is not in the Union at this line 😅.
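The failure mode under discussion can be reproduced in isolation. This is a hypothetical sketch of the bug pattern (the helper name collapse_union is illustrative, not the actual hf_argparser code): collapsing a Union without None down to a single bare type breaks any later code that unconditionally reads __args__.

```python
# Hypothetical sketch of the bug pattern: after filtering dict out of
# str | dict, the annotation collapses to bare str, and later code that
# expects Union machinery (e.g. reading __args__) raises AttributeError.
from typing import Union, get_args


def collapse_union(annotation):
    """Filter dict out of a Union, collapsing to a bare type if one member remains."""
    members = tuple(a for a in get_args(annotation) if a is not dict)
    return members[0] if len(members) == 1 else Union[members]


collapsed = collapse_union(Union[str, dict])  # bare str, no Union machinery left
try:
    collapsed.__args__
except AttributeError as err:
    # e.g. "type object 'str' has no attribute '__args__'"
    print(err)
```

This matches the AttributeError reported for the dict1/dict2 test cases above: with no NoneType in the Union, the filtered annotation is a plain type rather than a Union, so any unconditional __args__ access fails.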

@Tavish9
Contributor

Tavish9 commented Jul 31, 2025

I also understand the problem HfArgumentParser._parse_dataclass_field, it seems the current main branch assumes that dict | str or str | dict types don't exist in TrainingArguments because it filters out str types when NoneType not in Union at this line 😅.

you got it. 😄

@Tavish9
Contributor

Tavish9 commented Aug 5, 2025

@SunMarc check #39741 (comment), please.

If worth solving, I might spend some time on it.

@SunMarc
Member

SunMarc commented Aug 5, 2025

If worth solving, I might spend some time on it.

Definitely worth solving, I'm pretty sure we will get this issue in the future otherwise



Development

Successfully merging this pull request may close these issues.

HfArgumentParser cannot parse str for local path

4 participants