diff --git a/dargs/__init__.py b/dargs/__init__.py index 41a3f4d..858e9a4 100644 --- a/dargs/__init__.py +++ b/dargs/__init__.py @@ -1,3 +1,3 @@ -from .dargs import Argument, Variant +from .dargs import Argument, Variant, ArgumentEncoder -__all__ = ["Argument", "Variant"] \ No newline at end of file +__all__ = ["Argument", "Variant", "ArgumentEncoder"] \ No newline at end of file diff --git a/dargs/dargs.py b/dargs/dargs.py index c67885f..5ea6ad2 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -23,6 +23,7 @@ from copy import deepcopy from enum import Enum import fnmatch, re +import json INDENT = " " # doc is indented by four spaces @@ -632,4 +633,51 @@ def trim_by_pattern(argdict: dict, pattern: str, f"following reserved names: {', '.join(conflict)}") unrequired = list(filter(rem.match, argdict.keys())) for key in unrequired: - argdict.pop(key) \ No newline at end of file + argdict.pop(key) + + +class ArgumentEncoder(json.JSONEncoder): + """Extended JSON Encoder to encode Argument object: + + Examples + -------- + >>> json.dumps(some_arg, cls=ArgumentEncoder) + """ + def default(self, obj) -> Dict[str, Union[str, bool, List]]: + """Generate a dict containing argument information, making it ready to be encoded + to JSON string. + + Note + ---- + All object in the dict should be JSON serializable. + + Returns + ------- + dict: Dict + a dict containing argument information + """ + if isinstance(obj, Argument): + return { + "object": "Argument", + "name": obj.name, + "type": obj.dtype, + "optional": obj.optional, + "alias": obj.alias, + "doc": obj.doc, + "repeat": obj.repeat, + "sub_fields": obj.sub_fields, + "sub_variants": obj.sub_variants, + } + elif isinstance(obj, Variant): + return { + "object": "Variant", + "flag_name": obj.flag_name, + "optional": obj.optional, + "default_tag": obj.default_tag, + "choice_dict": obj.choice_dict, + "choice_alias": obj.choice_alias, + "doc": obj.doc, + } + elif isinstance(obj, type): + return obj.__name__ + return json.JSONEncoder.default(self, obj) diff --git a/tests/test_docgen.py b/tests/test_docgen.py index 89ba141..e5653d3 100644 --- a/tests/test_docgen.py +++ b/tests/test_docgen.py @@ -1,6 +1,7 @@ from context import dargs import unittest -from dargs import Argument, Variant +import json +from dargs import Argument, Variant, ArgumentEncoder class TestDocgen(unittest.TestCase): @@ -16,6 +17,7 @@ def test_sub_fields(self): ], doc="sub doc." * 5) ], doc="Base doc. " * 10) docstr = ca.gen_doc() + jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) def test_sub_repeat(self): @@ -29,6 +31,7 @@ def test_sub_repeat(self): ], doc="sub doc." * 5) ], doc="Base doc. " * 10, repeat=True) docstr = ca.gen_doc() + jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) def test_sub_variants(self): @@ -66,6 +69,7 @@ def test_sub_variants(self): ], optional=True, default_tag="type1", doc="another vnt") ]) docstr = ca.gen_doc(make_anchor=True) + jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) def test_multi_variants(self): @@ -110,6 +114,7 @@ def test_multi_variants(self): ]) ]) docstr = ca.gen_doc() + jsonstr = json.dumps(ca, cls=ArgumentEncoder) # print("\n\n"+docstr) def test_dpmd(self):