Allow for str versions of dicts based on typing#30227
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@amyeroberts figured out where to add a test, and verified that all 3 conditions raise their respective errors :) |
amyeroberts
left a comment
There was a problem hiding this comment.
Thanks for enabling this!
|
|
||
| def __post_init__(self): | ||
| # Parse in args that could be `dict` sent in from the CLI as a string | ||
| for field in VALID_DICT_FIELDS: |
There was a problem hiding this comment.
I like just checking for a few possible args!
| parser = HfArgumentParser(TrainingArguments) | ||
| self.assertIsNotNone(parser) | ||
|
|
||
| def test_valid_dict_annotation(self): |
There was a problem hiding this comment.
I know on offline discussion I said we probably don't need a test to check the parsing. Seeing the implementation, i.e. TrainingArguments can be created with a field as a string representation of a dict, and not having to include the CLI, I think we can add a simple for at least one of the fields in VALID_DICT_FIELDS
e.g. something along the lines of:
def test_valid_dict_input_parsing(self):
args = TrainingArguments(
field_name='{"key": value}'
)
# Or however it's assigned in the args
self.assertEqual(args.field_name, {key: value})There was a problem hiding this comment.
Done. This also made me notice that we cast int and bools as str still, so added a helper for this (and does so to avoid literal_eval, which can be exploited)
pacman100
left a comment
There was a problem hiding this comment.
Thank you @muellerzr for adding this super useful feature to be able to pass string representation of dict as cmd arguments! 🚀
This should be very useful for the CLI support that @younesbelkada worked wrt TRL.
amyeroberts
left a comment
There was a problem hiding this comment.
Great! Thanks for enabling this and adding these tests ❤️
| elif isinstance(value, str): | ||
| # First check for bool and convert | ||
| if value.lower() in ("true", "false"): | ||
| passed_value[key] = value.lower() == "true" |
| accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}', | ||
| ) | ||
| self.assertEqual(args.accelerator_config.split_batches, True) | ||
| self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2) |
What does this PR do?
This PR adds support for passing in a string dictionary to arguments that allow a dict in the argparser. For example:
(
test.pyis just a small script reading the args from the parsers):This also fixes issues with typing of not being able to state that
deepspeedandfsdp_configcan't have typedict. Turns out it's a (very) fun setting with the argparser wrapper we have for these. Types should be declared as (for this):This is very important to get working properly!
Fixes #30204
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@amyeroberts @pacman100
Technically everything will go brr if passes, since any call to the CLI/parser happens usually with examples etc. So new tests added, but did manually verify input/output from a basic script locally