diff --git a/rest_client_gen/dynamic_typing/base.py b/rest_client_gen/dynamic_typing/base.py index fe616b9..b26d3f3 100644 --- a/rest_client_gen/dynamic_typing/base.py +++ b/rest_client_gen/dynamic_typing/base.py @@ -1,6 +1,6 @@ from typing import Iterable, List, Tuple, Union -ImportPathList = List[Tuple[str, Union[Iterable[str], str]]] +ImportPathList = List[Tuple[str, Union[Iterable[str], str, None]]] class BaseType: diff --git a/rest_client_gen/dynamic_typing/typing.py b/rest_client_gen/dynamic_typing/typing.py index 8a149d5..cd93c22 100644 --- a/rest_client_gen/dynamic_typing/typing.py +++ b/rest_client_gen/dynamic_typing/typing.py @@ -28,19 +28,31 @@ def compile_imports(imports: ImportPathList) -> str: """ Merge list of imports path and convert them into list code (string) """ - imports_map: Dict[str, Set[str]] = OrderedDict() + class_imports_map: Dict[str, Set[str]] = OrderedDict() + package_imports_set: Set[str] = set() for module, classes in filter(None, imports): - classes_set = imports_map.get(module, set()) - if isinstance(classes, str): - classes_set.add(classes) + if classes is None: + package_imports_set.add(module) else: - classes_set.update(classes) - imports_map[module] = classes_set + classes_set = class_imports_map.get(module, set()) + if isinstance(classes, str): + classes_set.add(classes) + else: + classes_set.update(classes) + class_imports_map[module] = classes_set # Sort imports by package name and sort class names of each import - imports_map = OrderedDict(sorted( - ((module, sorted(classes)) for module, classes in imports_map.items()), + class_imports_map = OrderedDict(sorted( + ((module, sorted(classes)) for module, classes in class_imports_map.items()), key=operator.itemgetter(0) )) - return "\n".join(f"from {module} import {', '.join(classes)}" for module, classes in imports_map.items()) + class_imports = "\n".join( + f"from {module} import {', '.join(classes)}" + for module, classes in class_imports_map.items() + ) + package_imports = "\n".join( + f"import {module}" + for module in sorted(package_imports_set) + ) + return "\n".join(filter(None, (package_imports, class_imports))) diff --git a/rest_client_gen/generator.py b/rest_client_gen/generator.py index b3e1718..a9a7497 100644 --- a/rest_client_gen/generator.py +++ b/rest_client_gen/generator.py @@ -2,12 +2,10 @@ from enum import Enum from typing import Any, Callable, List, Optional, Union -import inflection from unidecode import unidecode -from rest_client_gen.dynamic_typing import ComplexType, SingleType -from .dynamic_typing import (DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, StringSerializable, - StringSerializableRegistry, Unknown, registry) +from .dynamic_typing import (ComplexType, DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, SingleType, + StringSerializable, StringSerializableRegistry, Unknown, registry) class Hierarchy(Enum): @@ -32,10 +30,6 @@ def __str__(self): class MetadataGenerator: CONVERTER_TYPE = Optional[Callable[[str], Any]] - # TODO: sep_style: SepStyle = SepStyle.Underscore - # TODO: hierarchy: Hierarchy = Hierarchy.Nested - # TODO: fpolicy: OptionalFieldsPolicy = OptionalFieldsPolicy.Optional - def __init__(self, str_types_registry: StringSerializableRegistry = None): self.str_types_registry = str_types_registry if str_types_registry is not None else registry @@ -57,8 +51,7 @@ def _convert(self, data: dict): # ! _detect_type function can crash at some complex data sets if value is unicode with some characters (maybe German) # Crash does not produce any useful logs and can occur any time after bad string was processed # It can be reproduced on real_apis tests (openlibrary API) - fields[inflection.underscore(key)] = self._detect_type(value if not isinstance(value, str) - else unidecode(value)) + fields[key] = self._detect_type(value if not isinstance(value, str) else unidecode(value)) return fields def _detect_type(self, value, convert_dict=True) -> MetaData: diff --git a/rest_client_gen/models/__init__.py b/rest_client_gen/models/__init__.py index 50bd811..1743741 100644 --- a/rest_client_gen/models/__init__.py +++ b/rest_client_gen/models/__init__.py @@ -1,7 +1,6 @@ from typing import Dict, Generic, Iterable, List, Set, Tuple, TypeVar -from rest_client_gen.dynamic_typing import DOptional -from ..dynamic_typing import ModelMeta, ModelPtr +from ..dynamic_typing import DOptional, ModelMeta, ModelPtr Index = str T = TypeVar('T') diff --git a/rest_client_gen/models/attr.py b/rest_client_gen/models/attr.py index e69de29..9c1c99b 100644 --- a/rest_client_gen/models/attr.py +++ b/rest_client_gen/models/attr.py @@ -0,0 +1,98 @@ +from inspect import isclass +from typing import Iterable, List, Tuple + +from .base import GenericModelCodeGenerator, template +from ..dynamic_typing import DList, DOptional, ImportPathList, MetaData, ModelMeta, StringSerializable + +METADATA_FIELD_NAME = "RCG_ORIGINAL_FIELD" +KWAGRS_TEMPLATE = "{% for key, value in kwargs.items() %}" \ + "{{ key }}={{ value }}" \ + "{% if not loop.last %}, {% endif %}" \ + "{% endfor %}" + +DEFAULT_ORDER = ( + ("default", "converter", "factory"), + "*", + ("metadata",) +) + + +def sort_kwargs(kwargs: dict, ordering: Iterable[Iterable[str]] = DEFAULT_ORDER) -> dict: + sorted_dict_1 = {} + sorted_dict_2 = {} + current = sorted_dict_1 + for group in ordering: + if isinstance(group, str): + if group != "*": + raise ValueError(f"Unknown kwarg group: {group}") + current = sorted_dict_2 + else: + for item in group: + if item in kwargs: + value = kwargs.pop(item) + current[item] = value + sorted_dict = {**sorted_dict_1, **kwargs, **sorted_dict_2} + return sorted_dict + + +class AttrsModelCodeGenerator(GenericModelCodeGenerator): + ATTRS = template("attr.s" + "{% if kwargs %}" + f"({KWAGRS_TEMPLATE})" + "{% endif %}") + ATTRIB = template(f"attr.ib({KWAGRS_TEMPLATE})") + + def __init__(self, model: ModelMeta, no_meta=False, attrs_kwargs: dict = None, **kwargs): + """ + :param model: ModelMeta instance + :param no_meta: Disable generation of metadata as attrib argument + :param attrs_kwargs: kwargs for @attr.s() decorators + :param kwargs: + """ + super().__init__(model, **kwargs) + self.no_meta = no_meta + self.attrs_kwargs = attrs_kwargs or {} + + def generate(self, nested_classes: List[str] = None) -> Tuple[ImportPathList, str]: + """ + :param nested_classes: list of strings that contains classes code + :return: list of import data, class code + """ + imports, code = super().generate(nested_classes) + imports.append(('attr', None)) + return imports, code + + @property + def decorators(self) -> List[str]: + """ + :return: List of decorators code (without @) + """ + return [self.ATTRS.render(kwargs=self.attrs_kwargs)] + + def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportPathList, dict]: + """ + Form field data for template + + :param name: Field name + :param meta: Field metadata + :param optional: Is field optional + :return: imports, field data + """ + imports, data = super().field_data(name, meta, optional) + body_kwargs = {} + if optional: + meta: DOptional + if isinstance(meta.type, DList): + body_kwargs["factory"] = "list" + else: + body_kwargs["default"] = "None" + if isclass(meta.type) and issubclass(meta.type, StringSerializable): + body_kwargs["converter"] = f"optional({meta.type.__name__})" + imports.append(("attr.converter", "optional")) + elif isclass(meta) and issubclass(meta, StringSerializable): + body_kwargs["converter"] = meta.__name__ + + if not self.no_meta: + body_kwargs["metadata"] = {METADATA_FIELD_NAME: name} + data["body"] = self.ATTRIB.render(kwargs=sort_kwargs(body_kwargs)) + return imports, data diff --git a/rest_client_gen/models/base.py b/rest_client_gen/models/base.py index ea92c0d..950e5a3 100644 --- a/rest_client_gen/models/base.py +++ b/rest_client_gen/models/base.py @@ -1,11 +1,10 @@ from typing import List, Tuple, Type +import inflection from jinja2 import Template -from rest_client_gen.dynamic_typing import AbsoluteModelRef, compile_imports -from rest_client_gen.models import INDENT, ModelsStructureType, OBJECTS_DELIMITER -from . import indent, sort_fields -from ..dynamic_typing import ImportPathList, MetaData, ModelMeta, metadata_to_typing +from . import INDENT, ModelsStructureType, OBJECTS_DELIMITER, indent, sort_fields +from ..dynamic_typing import AbsoluteModelRef, ImportPathList, MetaData, ModelMeta, compile_imports, metadata_to_typing def template(pattern: str, indent: str = INDENT) -> Template: @@ -82,7 +81,7 @@ def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportP """ imports, typing = metadata_to_typing(meta) data = { - "name": name, + "name": inflection.underscore(name), "type": typing } return imports, data diff --git a/test/test_code_generation/test_attrs_generation.py b/test/test_code_generation/test_attrs_generation.py new file mode 100644 index 0000000..0d0988a --- /dev/null +++ b/test/test_code_generation/test_attrs_generation.py @@ -0,0 +1,175 @@ +from typing import Dict, List + +import pytest + +from rest_client_gen.dynamic_typing import (DList, DOptional, FloatString, IntString, ModelMeta, compile_imports) +from rest_client_gen.models import sort_fields +from rest_client_gen.models.attr import AttrsModelCodeGenerator, METADATA_FIELD_NAME, sort_kwargs +from rest_client_gen.models.base import generate_code +from test.test_code_generation.test_models_code_generator import model_factory, trim + + +def test_attrib_kwargs_sort(): + sorted_kwargs = sort_kwargs(dict( + y=2, + metadata='b', + converter='a', + default=None, + x=1, + )) + expected = ['default', 'converter', 'y', 'x', 'metadata'] + for k1, k2 in zip(sorted_kwargs.keys(), expected): + assert k1 == k2 + try: + sort_kwargs({}, ['wrong_char']) + except ValueError as e: + assert e.args[0].endswith('wrong_char') + else: + assert 0, "XPass" + + + +def field_meta(original_name): + return f"metadata={{'{METADATA_FIELD_NAME}': '{original_name}'}}" + + +# Data structure: +# pytest.param id -> { +# "model" -> (model_name, model_metadata), +# test_name -> expected, ... +# } +test_data = { + "base": { + "model": ("Test", { + "foo": int, + "bar": int, + "baz": float + }), + "fields_data": { + "foo": { + "name": "foo", + "type": "int", + "body": f"attr.ib({field_meta('foo')})" + }, + "bar": { + "name": "bar", + "type": "int", + "body": f"attr.ib({field_meta('bar')})" + }, + "baz": { + "name": "baz", + "type": "float", + "body": f"attr.ib({field_meta('baz')})" + } + }, + "fields": { + "imports": "", + "fields": [ + f"foo: int = attr.ib({field_meta('foo')})", + f"bar: int = attr.ib({field_meta('bar')})", + f"baz: float = attr.ib({field_meta('baz')})", + ] + }, + "generated": trim(f""" + import attr + + + @attr.s + class Test: + foo: int = attr.ib({field_meta('foo')}) + bar: int = attr.ib({field_meta('bar')}) + baz: float = attr.ib({field_meta('baz')}) + """) + }, + "complex": { + "model": ("Test", { + "foo": int, + "baz": DOptional(DList(DList(str))), + "bar": DOptional(IntString), + "qwerty": FloatString, + "asdfg": DOptional(int) + }), + "fields_data": { + "foo": { + "name": "foo", + "type": "int", + "body": f"attr.ib({field_meta('foo')})" + }, + "baz": { + "name": "baz", + "type": "Optional[List[List[str]]]", + "body": f"attr.ib(factory=list, {field_meta('baz')})" + }, + "bar": { + "name": "bar", + "type": "Optional[IntString]", + "body": f"attr.ib(default=None, converter=optional(IntString), {field_meta('bar')})" + }, + "qwerty": { + "name": "qwerty", + "type": "FloatString", + "body": f"attr.ib(converter=FloatString, {field_meta('qwerty')})" + }, + "asdfg": { + "name": "asdfg", + "type": "Optional[int]", + "body": f"attr.ib(default=None, {field_meta('asdfg')})" + } + }, + "generated": trim(f""" + import attr + from attr.converter import optional + from rest_client_gen.dynamic_typing.string_serializable import FloatString, IntString + from typing import List, Optional + + + @attr.s + class Test: + foo: int = attr.ib({field_meta('foo')}) + qwerty: FloatString = attr.ib(converter=FloatString, {field_meta('qwerty')}) + baz: Optional[List[List[str]]] = attr.ib(factory=list, {field_meta('baz')}) + bar: Optional[IntString] = attr.ib(default=None, converter=optional(IntString), {field_meta('bar')}) + asdfg: Optional[int] = attr.ib(default=None, {field_meta('asdfg')}) + """) + } +} + +test_data_unzip = { + test: [ + pytest.param( + model_factory(*data["model"]), + data[test], + id=id + ) + for id, data in test_data.items() + if test in data + ] + for test in ("fields_data", "fields", "generated") +} + + +@pytest.mark.parametrize("value,expected", test_data_unzip["fields_data"]) +def test_fields_data_attr(value: ModelMeta, expected: Dict[str, dict]): + gen = AttrsModelCodeGenerator(value) + required, optional = sort_fields(value) + for is_optional, fields in enumerate((required, optional)): + for field in fields: + field_imports, data = gen.field_data(field, value.type[field], bool(is_optional)) + assert data == expected[field] + + +@pytest.mark.parametrize("value,expected", test_data_unzip["fields"]) +def test_fields_attr(value: ModelMeta, expected: dict): + expected_imports: str = expected["imports"] + expected_fields: List[str] = expected["fields"] + gen = AttrsModelCodeGenerator(value) + imports, fields = gen.fields + imports = compile_imports(imports) + assert imports == expected_imports + assert fields == expected_fields + + +@pytest.mark.parametrize("value,expected", test_data_unzip["generated"]) +def test_generated_attr(value: ModelMeta, expected: str): + generated = generate_code(([{"model": value, "nested": []}], {}), AttrsModelCodeGenerator) + assert generated.rstrip() == expected, generated diff --git a/test/test_code_generation/test_models_code_generator.py b/test/test_code_generation/test_models_code_generator.py index 1c10fd3..4807dd7 100644 --- a/test/test_code_generation/test_models_code_generator.py +++ b/test/test_code_generation/test_models_code_generator.py @@ -7,6 +7,9 @@ from rest_client_gen.models import indent, sort_fields from rest_client_gen.models.base import GenericModelCodeGenerator, generate_code +# Data structure: +# (string, indent lvl, indent string) +# result test_indent_data = [ pytest.param( ("1", 1, " " * 4), diff --git a/test/test_code_generation/test_typing.py b/test/test_code_generation/test_typing.py index 01399cb..41eda00 100644 --- a/test/test_code_generation/test_typing.py +++ b/test/test_code_generation/test_typing.py @@ -23,7 +23,24 @@ def test_metadata_to_typing_with_dict(): "from pytest import param\n" "from typing import Any, List, Tuple", id="basic" - ) + ), + pytest.param( + [ + ('typing', ('List', 'Any')), + [], + ('typing', None), + [], + ('pytest', 'param'), + ('typing', ('List', 'Tuple')), + ('attr', None), + ('typing', None), + ], + "import attr\n" + "import typing\n" + "from pytest import param\n" + "from typing import Any, List, Tuple", + id="basic" + ), ] diff --git a/testing_tools/real_apis/f1.py b/testing_tools/real_apis/f1.py index 14a6aae..cb88752 100644 --- a/testing_tools/real_apis/f1.py +++ b/testing_tools/real_apis/f1.py @@ -7,7 +7,8 @@ from rest_client_gen.generator import MetadataGenerator from rest_client_gen.models import compose_models -from rest_client_gen.models.base import GenericModelCodeGenerator, generate_code +from rest_client_gen.models.attr import AttrsModelCodeGenerator +from rest_client_gen.models.base import generate_code from rest_client_gen.registry import ModelRegistry from rest_client_gen.utils import json_format from testing_tools.pprint_meta_data import pretty_format_meta @@ -58,7 +59,7 @@ def main(): print('\n', json_format([structure[0], {str(a): str(b) for a, b in structure[1].items()}])) print("=" * 20) - print(generate_code(structure, GenericModelCodeGenerator)) + print(generate_code(structure, AttrsModelCodeGenerator)) if __name__ == '__main__': diff --git a/testing_tools/real_apis/pathofexile.py b/testing_tools/real_apis/pathofexile.py index d68e490..614a9d9 100644 --- a/testing_tools/real_apis/pathofexile.py +++ b/testing_tools/real_apis/pathofexile.py @@ -5,7 +5,8 @@ from rest_client_gen.generator import MetadataGenerator from rest_client_gen.models import compose_models -from rest_client_gen.models.base import GenericModelCodeGenerator, generate_code +from rest_client_gen.models.attr import AttrsModelCodeGenerator +from rest_client_gen.models.base import generate_code from rest_client_gen.registry import ModelRegistry from rest_client_gen.utils import json_format from testing_tools.pprint_meta_data import pretty_format_meta @@ -36,7 +37,7 @@ def main(): print('\n', json_format([structure[0], {str(a): str(b) for a, b in structure[1].items()}])) print("=" * 20) - print(generate_code(structure, GenericModelCodeGenerator)) + print(generate_code(structure, AttrsModelCodeGenerator)) if __name__ == '__main__':