Fix HfArgumentParser to filter out dict types from Union #39741
st81 wants to merge 1 commit into huggingface:main from
Conversation
> This prevents runtime errors with argparse, which does not support parsing arguments as dicts.

Please show me the smallest example that could fail on the main branch.

> could break if new `Union[str, dict]` annotations are introduced

This would not happen, since it cannot pass the tests I added.

> This change ensures that dict types are always filtered out

This is not reasonable: people cannot pass a dict to the parser anymore. Also, this change could not pass the tests :(
@Tavish9 Thank you for your feedback!
For example, if we add the new argument:

```python
def _get_torch_dtype(
    cls,
    torch_dtype: Optional[Union[str, torch.dtype, dict]],
    checkpoint_files: Optional[list[str]],
    config: PretrainedConfig,
    sharded_metadata: Optional[dict],
    state_dict: Optional[dict],
    weights_only: bool,
    new_arg: Optional[Union[str, dict]] = None,  # <-- this is the new argument
) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
```

then the following code snippet will fail on the main branch:

```python
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)


def main():
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]
    print("parse success")
    print(type(training_args.accelerator_config.gradient_accumulation_kwargs))


if __name__ == "__main__":
    main()
```

```shell
python3 test.py --accelerator_config '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
```

```
test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
```
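The mechanism behind this error is visible in the traceback that follows: argparse calls the field's annotated type directly on the raw CLI string, and calling `dict(...)` on a string iterates its characters, so it fails before any JSON parsing can happen. A minimal reproduction outside argparse:

```python
import json

s = '{"gradient_accumulation_kwargs": {"num_steps": 2}}'

# dict() over a string iterates its characters; each element must be a
# (key, value) pair, hence the length-1 error that argparse surfaces:
try:
    dict(s)
except ValueError as e:
    print(e)  # dictionary update sequence element #0 has length 1; 2 is required

# Parsing the same string as JSON works fine:
print(json.loads(s))  # {'gradient_accumulation_kwargs': {'num_steps': 2}}
```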
You're absolutely right! The existing tests fail on the main branch. The error message is long, so I enclose it in a details tag.

```shell
python3 -m pytest tests/utils/test_hf_argparser.py
```
__________________________________________________________________________ HfArgumentParserTest.test_16_cli_input_parsing __________________________________________________________________________
self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
action = _StoreAction(option_strings=['--accelerator_config', '--accelerator-config'], dest='accelerator_config', nargs=None, c...ccelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`.', metavar=None)
arg_string = '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
def _get_value(self, action, arg_string):
type_func = self._registry_get('type', action.type, action.type)
if not callable(type_func):
msg = _('%r is not callable')
raise ArgumentError(action, msg % type_func)
# convert the value to the appropriate type
try:
> result = type_func(arg_string)
^^^^^^^^^^^^^^^^^^^^^
E ValueError: dictionary update sequence element #0 has length 1; 2 is required
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2546: ValueError
During handling of the above exception, another exception occurred:
self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
args = ['--accelerator_config', '{"gradient_accumulation_kwargs": {"num_steps": 2}}']
namespace = Namespace(output_dir=None, overwrite_output_dir=False, do_train=False, do_eval=False, do_predict=False, eval_strategy=...se, use_liger_kernel=False, liger_kernel_config=None, eval_use_gather_object=False, average_tokens_across_devices=True)
intermixed = False
def _parse_known_args2(self, args, namespace, intermixed):
if args is None:
# args default to the system args
args = _sys.argv[1:]
else:
# make sure that args are mutable
args = list(args)
# default Namespace built from parser defaults
if namespace is None:
namespace = Namespace()
# add any action defaults that aren't present
for action in self._actions:
if action.dest is not SUPPRESS:
if not hasattr(namespace, action.dest):
if action.default is not SUPPRESS:
setattr(namespace, action.dest, action.default)
# add any parser defaults that aren't present
for dest in self._defaults:
if not hasattr(namespace, dest):
setattr(namespace, dest, self._defaults[dest])
# parse the arguments and exit if there are any errors
if self.exit_on_error:
try:
> namespace, args = self._parse_known_args(args, namespace, intermixed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1943:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2184: in _parse_known_args
start_index = consume_optional(start_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2113: in consume_optional
take_action(action, args, option_string)
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2003: in take_action
argument_values = self._get_values(action, argument_strings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2513: in _get_values
value = self._get_value(action, arg_string)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True)
action = _StoreAction(option_strings=['--accelerator_config', '--accelerator-config'], dest='accelerator_config', nargs=None, c...ccelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`.', metavar=None)
arg_string = '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
def _get_value(self, action, arg_string):
type_func = self._registry_get('type', action.type, action.type)
if not callable(type_func):
msg = _('%r is not callable')
raise ArgumentError(action, msg % type_func)
# convert the value to the appropriate type
try:
result = type_func(arg_string)
# ArgumentTypeErrors indicate errors
except ArgumentTypeError as err:
msg = str(err)
raise ArgumentError(action, msg)
# TypeErrors or ValueErrors also indicate errors
except (TypeError, ValueError):
name = getattr(action.type, '__name__', repr(action.type))
args = {'type': name, 'value': arg_string}
msg = _('invalid %(type)s value: %(value)r')
> raise ArgumentError(action, msg % args)
E argparse.ArgumentError: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2558: ArgumentError
During handling of the above exception, another exception occurred:
self = <tests.utils.test_hf_argparser.HfArgumentParserTest testMethod=test_16_cli_input_parsing>
@require_torch
@patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
def test_16_cli_input_parsing(self):
parser = HfArgumentParser(TrainingArguments)
> training_args = parser.parse_args_into_dataclasses()[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tests/utils/test_hf_argparser.py:491:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/hf_argparser.py:351: in parse_args_into_dataclasses
namespace, remaining_args = self.parse_known_args(args=args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1914: in parse_known_args
return self._parse_known_args2(args, namespace, intermixed=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:1945: in _parse_known_args2
self.error(str(err))
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2650: in error
self.exit(2, _('%(prog)s: error: %(message)s\n') % args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = HfArgumentParser(prog='test.py', usage=None, description=None, formatter_class=<class 'argparse.ArgumentDefaultsHelpFormatter'>, conflict_handler='error', add_help=True), status = 2
message = 'test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: \'{"gradient_accumulation_kwargs": {"num_steps": 2}}\'\n'
def exit(self, status=0, message=None):
if message:
self._print_message(message, _sys.stderr)
> _sys.exit(status)
E SystemExit: 2
../../../.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/argparse.py:2637: SystemExit
--------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
usage: test.py [-h] [--output_dir OUTPUT_DIR]
[--overwrite_output_dir [OVERWRITE_OUTPUT_DIR]]
[--do_train [DO_TRAIN]] [--do_eval [DO_EVAL]]
[--do_predict [DO_PREDICT]] [--eval_strategy {no,steps,epoch}]
[--prediction_loss_only [PREDICTION_LOSS_ONLY]]
[--per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE]
[--per_device_eval_batch_size PER_DEVICE_EVAL_BATCH_SIZE]
[--per_gpu_train_batch_size PER_GPU_TRAIN_BATCH_SIZE]
[--per_gpu_eval_batch_size PER_GPU_EVAL_BATCH_SIZE]
[--gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS]
[--eval_accumulation_steps EVAL_ACCUMULATION_STEPS]
[--eval_delay EVAL_DELAY]
[--torch_empty_cache_steps TORCH_EMPTY_CACHE_STEPS]
[--learning_rate LEARNING_RATE] [--weight_decay WEIGHT_DECAY]
[--adam_beta1 ADAM_BETA1] [--adam_beta2 ADAM_BETA2]
[--adam_epsilon ADAM_EPSILON] [--max_grad_norm MAX_GRAD_NORM]
[--num_train_epochs NUM_TRAIN_EPOCHS] [--max_steps MAX_STEPS]
[--lr_scheduler_type {linear,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau,cosine_with_min_lr,cosine_warmup_with_min_lr,warmup_stable_decay}]
[--lr_scheduler_kwargs LR_SCHEDULER_KWARGS]
[--warmup_ratio WARMUP_RATIO] [--warmup_steps WARMUP_STEPS]
[--log_level {detail,debug,info,warning,error,critical,passive}]
[--log_level_replica {detail,debug,info,warning,error,critical,passive}]
[--log_on_each_node [LOG_ON_EACH_NODE]] [--no_log_on_each_node]
[--logging_dir LOGGING_DIR]
[--logging_strategy {no,steps,epoch}]
[--logging_first_step [LOGGING_FIRST_STEP]]
[--logging_steps LOGGING_STEPS]
[--logging_nan_inf_filter [LOGGING_NAN_INF_FILTER]]
[--no_logging_nan_inf_filter]
[--save_strategy {no,steps,epoch,best}]
[--save_steps SAVE_STEPS] [--save_total_limit SAVE_TOTAL_LIMIT]
[--save_safetensors [SAVE_SAFETENSORS]] [--no_save_safetensors]
[--save_on_each_node [SAVE_ON_EACH_NODE]]
[--save_only_model [SAVE_ONLY_MODEL]]
[--restore_callback_states_from_checkpoint [RESTORE_CALLBACK_STATES_FROM_CHECKPOINT]]
[--no_cuda [NO_CUDA]] [--use_cpu [USE_CPU]]
[--use_mps_device [USE_MPS_DEVICE]] [--seed SEED]
[--data_seed DATA_SEED] [--jit_mode_eval [JIT_MODE_EVAL]]
[--use_ipex [USE_IPEX]] [--bf16 [BF16]] [--fp16 [FP16]]
[--fp16_opt_level FP16_OPT_LEVEL]
[--half_precision_backend {auto,apex,cpu_amp}]
[--bf16_full_eval [BF16_FULL_EVAL]]
[--fp16_full_eval [FP16_FULL_EVAL]] [--tf32 TF32]
[--local_rank LOCAL_RANK]
[--ddp_backend {nccl,gloo,mpi,ccl,hccl,cncl,mccl}]
[--tpu_num_cores TPU_NUM_CORES]
[--tpu_metrics_debug [TPU_METRICS_DEBUG]]
[--debug DEBUG [DEBUG ...]]
[--dataloader_drop_last [DATALOADER_DROP_LAST]]
[--eval_steps EVAL_STEPS]
[--dataloader_num_workers DATALOADER_NUM_WORKERS]
[--dataloader_prefetch_factor DATALOADER_PREFETCH_FACTOR]
[--past_index PAST_INDEX] [--run_name RUN_NAME]
[--disable_tqdm DISABLE_TQDM]
[--remove_unused_columns [REMOVE_UNUSED_COLUMNS]]
[--no_remove_unused_columns]
[--label_names LABEL_NAMES [LABEL_NAMES ...]]
[--load_best_model_at_end [LOAD_BEST_MODEL_AT_END]]
[--metric_for_best_model METRIC_FOR_BEST_MODEL]
[--greater_is_better GREATER_IS_BETTER]
[--ignore_data_skip [IGNORE_DATA_SKIP]] [--fsdp FSDP]
[--fsdp_min_num_params FSDP_MIN_NUM_PARAMS]
[--fsdp_config FSDP_CONFIG]
[--fsdp_transformer_layer_cls_to_wrap FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP]
[--accelerator_config ACCELERATOR_CONFIG]
[--deepspeed DEEPSPEED]
[--label_smoothing_factor LABEL_SMOOTHING_FACTOR]
[--optim {adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_torch_npu_fused,adamw_apex_fused,adafactor,adamw_anyprecision,adamw_torch_4bit,adamw_torch_8bit,ademamix,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,ademamix_8bit,lion_8bit,lion_32bit,paged_adamw_32bit,paged_adamw_8bit,paged_ademamix_32bit,paged_ademamix_8bit,paged_lion_32bit,paged_lion_8bit,rmsprop,rmsprop_bnb,rmsprop_bnb_8bit,rmsprop_bnb_32bit,galore_adamw,galore_adamw_8bit,galore_adafactor,galore_adamw_layerwise,galore_adamw_8bit_layerwise,galore_adafactor_layerwise,lomo,adalomo,grokadamw,schedule_free_radam,schedule_free_adamw,schedule_free_sgd,apollo_adamw,apollo_adamw_layerwise,stable_adamw}]
[--optim_args OPTIM_ARGS] [--adafactor [ADAFACTOR]]
[--group_by_length [GROUP_BY_LENGTH]]
[--length_column_name LENGTH_COLUMN_NAME]
[--report_to REPORT_TO]
[--ddp_find_unused_parameters DDP_FIND_UNUSED_PARAMETERS]
[--ddp_bucket_cap_mb DDP_BUCKET_CAP_MB]
[--ddp_broadcast_buffers DDP_BROADCAST_BUFFERS]
[--dataloader_pin_memory [DATALOADER_PIN_MEMORY]]
[--no_dataloader_pin_memory]
[--dataloader_persistent_workers [DATALOADER_PERSISTENT_WORKERS]]
[--skip_memory_metrics [SKIP_MEMORY_METRICS]]
[--no_skip_memory_metrics]
[--use_legacy_prediction_loop [USE_LEGACY_PREDICTION_LOOP]]
[--push_to_hub [PUSH_TO_HUB]]
[--resume_from_checkpoint RESUME_FROM_CHECKPOINT]
[--hub_model_id HUB_MODEL_ID]
[--hub_strategy {end,every_save,checkpoint,all_checkpoints}]
[--hub_token HUB_TOKEN] [--hub_private_repo HUB_PRIVATE_REPO]
[--hub_always_push [HUB_ALWAYS_PUSH]]
[--hub_revision HUB_REVISION]
[--gradient_checkpointing [GRADIENT_CHECKPOINTING]]
[--gradient_checkpointing_kwargs GRADIENT_CHECKPOINTING_KWARGS]
[--include_inputs_for_metrics [INCLUDE_INPUTS_FOR_METRICS]]
[--include_for_metrics INCLUDE_FOR_METRICS [INCLUDE_FOR_METRICS ...]]
[--eval_do_concat_batches [EVAL_DO_CONCAT_BATCHES]]
[--no_eval_do_concat_batches]
[--fp16_backend {auto,apex,cpu_amp}]
[--push_to_hub_model_id PUSH_TO_HUB_MODEL_ID]
[--push_to_hub_organization PUSH_TO_HUB_ORGANIZATION]
[--push_to_hub_token PUSH_TO_HUB_TOKEN]
[--mp_parameters MP_PARAMETERS]
[--auto_find_batch_size [AUTO_FIND_BATCH_SIZE]]
[--full_determinism [FULL_DETERMINISM]]
[--torchdynamo TORCHDYNAMO] [--ray_scope RAY_SCOPE]
[--ddp_timeout DDP_TIMEOUT] [--torch_compile [TORCH_COMPILE]]
[--torch_compile_backend TORCH_COMPILE_BACKEND]
[--torch_compile_mode TORCH_COMPILE_MODE]
[--include_tokens_per_second [INCLUDE_TOKENS_PER_SECOND]]
[--include_num_input_tokens_seen [INCLUDE_NUM_INPUT_TOKENS_SEEN]]
[--neftune_noise_alpha NEFTUNE_NOISE_ALPHA]
[--optim_target_modules OPTIM_TARGET_MODULES]
[--batch_eval_metrics [BATCH_EVAL_METRICS]]
[--eval_on_start [EVAL_ON_START]]
[--use_liger_kernel [USE_LIGER_KERNEL]]
[--liger_kernel_config LIGER_KERNEL_CONFIG]
[--eval_use_gather_object [EVAL_USE_GATHER_OBJECT]]
[--average_tokens_across_devices [AVERAGE_TOKENS_ACROSS_DEVICES]]
[--no_average_tokens_across_devices]
test.py: error: argument --accelerator_config/--accelerator-config: invalid dict value: '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
===================================================================================== short test summary info ======================================================================================
FAILED tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_16_cli_input_parsing - SystemExit: 2
============================= 1 failed, 16 passed, 9 warnings in 7.52s =============================

But I think this would be confusing for developers because:

Users can still pass dicts via JSON strings (which I think is the intended behavior) and they get parsed as dicts:

```shell
python3 test.py --accelerator_config '{"gradient_accumulation_kwargs": {"num_steps": 2}}'
```

```
parse success
<class 'dict'>
```

Also, the tests pass with this change:

```shell
python3 -m pytest tests/utils/test_hf_argparser.py
```
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
configfile: pyproject.toml
plugins: xdist-3.8.0
collected 17 items
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_00_basic PASSED [ 5%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_01_with_default PASSED [ 11%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_02_with_default_bool PASSED [ 17%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_03_with_enum PASSED [ 23%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_04_with_literal PASSED [ 29%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_05_with_list PASSED [ 35%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_06_with_optional PASSED [ 41%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_07_with_required PASSED [ 47%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_08_with_string_literal_annotation PASSED [ 52%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_09_parse_dict PASSED [ 58%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_10_parse_dict_extra_key PASSED [ 64%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_11_parse_json PASSED [ 70%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_12_parse_yaml PASSED [ 76%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_13_valid_dict_annotation PASSED [ 82%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_14_valid_dict_input_parsing PASSED [ 88%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_15_integration_training_args PASSED [ 94%]
tests/utils/test_hf_argparser.py::HfArgumentParserTest::test_16_cli_input_parsing PASSED [100%]
============================= 17 passed, 1 warning in 0.22s =============================
As long as
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the detailed response!

Apologies for the misunderstanding. Yes, it should work. But I think we can handle this more gracefully. Let me figure it out.
After diving into this, I found that neither the main branch nor this PR could solve this problem directly. It requires re-designing. If yes, I would open a new PR fixing both.
Can you specify in more detail which case still fails with this PR? Happy to have the PR if you are willing to spend some time on that!
```python
import unittest
from dataclasses import dataclass, field
from unittest.mock import patch

from transformers import (
    HfArgumentParser,
    TrainingArguments,
)


@dataclass
class MyTrainingArguments(TrainingArguments):
    dict1: dict | str = field(default="")
    dict2: str | dict = field(default_factory=lambda: {"test": True})

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["dict1", "dict2"]


class HfArgumentParserTest(unittest.TestCase):
    @patch("sys.argv", ["test.py", "--dict1", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
    def test_01(self):
        parser = HfArgumentParser(MyTrainingArguments)
        training_args = parser.parse_args_into_dataclasses()[0]
        assert training_args.dict1 == {"gradient_accumulation_kwargs": {"num_steps": 2}}
        assert training_args.dict2 == {"test": True}

    @patch("sys.argv", ["test.py", "--dict2", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
    def test_02(self):
        parser = HfArgumentParser(MyTrainingArguments)
        training_args = parser.parse_args_into_dataclasses()[0]
        assert training_args.dict2 == {"gradient_accumulation_kwargs": {"num_steps": 2}}
        assert training_args.dict1 == ""
```

A new PR needs to add this unittest as well. @SunMarc, cc @st81
@Tavish9 I'm fine with you taking time to handle this issue in a new PR! I also understand the problem.
You got it. 😄
@SunMarc, please check #39741 (comment). If it's worth solving, I might spend some time on it.
Definitely worth solving; I'm pretty sure we will get this issue in the future otherwise.
What does this PR do?
This PR updates `HfArgumentParser` to filter out dict types from `Union` annotations. This prevents runtime errors with argparse, which does not support parsing arguments as dicts. Previous work (PR #39467) reordered `Union[str, dict]` to `Union[dict, str]` throughout the codebase, but this approach is fragile and could break if new `Union[str, dict]` annotations are introduced. This change ensures that dict types are always filtered out, making the parser more robust and maintainable.

Fixes #39462
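As a rough illustration of the filtering idea (the helper name and exact behavior below are illustrative sketches, not the actual patch; dict-valued input would still be accepted separately, e.g. by JSON-parsing the string):

```python
from typing import Optional, Union, get_args, get_origin


def drop_dict_from_union(field_type):
    """Illustrative helper: remove dict from a Union annotation so argparse
    never receives `dict` as its type callable."""
    if get_origin(field_type) is Union:
        remaining = tuple(t for t in get_args(field_type) if t is not dict)
        if len(remaining) == 1:
            return remaining[0]
        return Union[remaining]
    return field_type


print(drop_dict_from_union(Union[str, dict]))            # <class 'str'>
print(drop_dict_from_union(Optional[Union[str, dict]]))  # typing.Optional[str]
```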
Who can review?
(same as previous PR)