Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rest_client_gen/dynamic_typing/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterable, List, Tuple, Union

ImportPathList = List[Tuple[str, Union[Iterable[str], str]]]
ImportPathList = List[Tuple[str, Union[Iterable[str], str, None]]]


class BaseType:
Expand Down
30 changes: 21 additions & 9 deletions rest_client_gen/dynamic_typing/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,31 @@ def compile_imports(imports: ImportPathList) -> str:
"""
Merge list of imports path and convert them into list code (string)
"""
imports_map: Dict[str, Set[str]] = OrderedDict()
class_imports_map: Dict[str, Set[str]] = OrderedDict()
package_imports_set: Set[str] = set()
for module, classes in filter(None, imports):
classes_set = imports_map.get(module, set())
if isinstance(classes, str):
classes_set.add(classes)
if classes is None:
package_imports_set.add(module)
else:
classes_set.update(classes)
imports_map[module] = classes_set
classes_set = class_imports_map.get(module, set())
if isinstance(classes, str):
classes_set.add(classes)
else:
classes_set.update(classes)
class_imports_map[module] = classes_set

# Sort imports by package name and sort class names of each import
imports_map = OrderedDict(sorted(
((module, sorted(classes)) for module, classes in imports_map.items()),
class_imports_map = OrderedDict(sorted(
((module, sorted(classes)) for module, classes in class_imports_map.items()),
key=operator.itemgetter(0)
))

return "\n".join(f"from {module} import {', '.join(classes)}" for module, classes in imports_map.items())
class_imports = "\n".join(
f"from {module} import {', '.join(classes)}"
for module, classes in class_imports_map.items()
)
package_imports = "\n".join(
f"import {module}"
for module in sorted(package_imports_set)
)
return "\n".join(filter(None, (package_imports, class_imports)))
13 changes: 3 additions & 10 deletions rest_client_gen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from enum import Enum
from typing import Any, Callable, List, Optional, Union

import inflection
from unidecode import unidecode

from rest_client_gen.dynamic_typing import ComplexType, SingleType
from .dynamic_typing import (DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, StringSerializable,
StringSerializableRegistry, Unknown, registry)
from .dynamic_typing import (ComplexType, DList, DOptional, DUnion, MetaData, ModelPtr, NoneType, SingleType,
StringSerializable, StringSerializableRegistry, Unknown, registry)


class Hierarchy(Enum):
Expand All @@ -32,10 +30,6 @@ def __str__(self):
class MetadataGenerator:
CONVERTER_TYPE = Optional[Callable[[str], Any]]

# TODO: sep_style: SepStyle = SepStyle.Underscore
# TODO: hierarchy: Hierarchy = Hierarchy.Nested
# TODO: fpolicy: OptionalFieldsPolicy = OptionalFieldsPolicy.Optional

def __init__(self, str_types_registry: StringSerializableRegistry = None):
self.str_types_registry = str_types_registry if str_types_registry is not None else registry

Expand All @@ -57,8 +51,7 @@ def _convert(self, data: dict):
# ! _detect_type function can crash at some complex data sets if value is unicode with some characters (maybe German)
# Crash does not produce any useful logs and can occur any time after bad string was processed
# It can be reproduced on real_apis tests (openlibrary API)
fields[inflection.underscore(key)] = self._detect_type(value if not isinstance(value, str)
else unidecode(value))
fields[key] = self._detect_type(value if not isinstance(value, str) else unidecode(value))
return fields

def _detect_type(self, value, convert_dict=True) -> MetaData:
Expand Down
3 changes: 1 addition & 2 deletions rest_client_gen/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Dict, Generic, Iterable, List, Set, Tuple, TypeVar

from rest_client_gen.dynamic_typing import DOptional
from ..dynamic_typing import ModelMeta, ModelPtr
from ..dynamic_typing import DOptional, ModelMeta, ModelPtr

Index = str
T = TypeVar('T')
Expand Down
98 changes: 98 additions & 0 deletions rest_client_gen/models/attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from inspect import isclass
from typing import Iterable, List, Tuple

from .base import GenericModelCodeGenerator, template
from ..dynamic_typing import DList, DOptional, ImportPathList, MetaData, ModelMeta, StringSerializable

METADATA_FIELD_NAME = "RCG_ORIGINAL_FIELD"
KWAGRS_TEMPLATE = "{% for key, value in kwargs.items() %}" \
"{{ key }}={{ value }}" \
"{% if not loop.last %}, {% endif %}" \
"{% endfor %}"

DEFAULT_ORDER = (
("default", "converter", "factory"),
"*",
("metadata",)
)


def sort_kwargs(kwargs: dict, ordering: Iterable[Iterable[str]] = DEFAULT_ORDER) -> dict:
sorted_dict_1 = {}
sorted_dict_2 = {}
current = sorted_dict_1
for group in ordering:
if isinstance(group, str):
if group != "*":
raise ValueError(f"Unknown kwarg group: {group}")
current = sorted_dict_2
else:
for item in group:
if item in kwargs:
value = kwargs.pop(item)
current[item] = value
sorted_dict = {**sorted_dict_1, **kwargs, **sorted_dict_2}
return sorted_dict


class AttrsModelCodeGenerator(GenericModelCodeGenerator):
ATTRS = template("attr.s"
"{% if kwargs %}"
f"({KWAGRS_TEMPLATE})"
"{% endif %}")
ATTRIB = template(f"attr.ib({KWAGRS_TEMPLATE})")

def __init__(self, model: ModelMeta, no_meta=False, attrs_kwargs: dict = None, **kwargs):
"""
:param model: ModelMeta instance
:param no_meta: Disable generation of metadata as attrib argument
:param attrs_kwargs: kwargs for @attr.s() decorators
:param kwargs:
"""
super().__init__(model, **kwargs)
self.no_meta = no_meta
self.attrs_kwargs = attrs_kwargs or {}

def generate(self, nested_classes: List[str] = None) -> Tuple[ImportPathList, str]:
"""
:param nested_classes: list of strings that contains classes code
:return: list of import data, class code
"""
imports, code = super().generate(nested_classes)
imports.append(('attr', None))
return imports, code

@property
def decorators(self) -> List[str]:
"""
:return: List of decorators code (without @)
"""
return [self.ATTRS.render(kwargs=self.attrs_kwargs)]

def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportPathList, dict]:
"""
Form field data for template

:param name: Field name
:param meta: Field metadata
:param optional: Is field optional
:return: imports, field data
"""
imports, data = super().field_data(name, meta, optional)
body_kwargs = {}
if optional:
meta: DOptional
if isinstance(meta.type, DList):
body_kwargs["factory"] = "list"
else:
body_kwargs["default"] = "None"
if isclass(meta.type) and issubclass(meta.type, StringSerializable):
body_kwargs["converter"] = f"optional({meta.type.__name__})"
imports.append(("attr.converter", "optional"))
elif isclass(meta) and issubclass(meta, StringSerializable):
body_kwargs["converter"] = meta.__name__

if not self.no_meta:
body_kwargs["metadata"] = {METADATA_FIELD_NAME: name}
data["body"] = self.ATTRIB.render(kwargs=sort_kwargs(body_kwargs))
return imports, data
9 changes: 4 additions & 5 deletions rest_client_gen/models/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import List, Tuple, Type

import inflection
from jinja2 import Template

from rest_client_gen.dynamic_typing import AbsoluteModelRef, compile_imports
from rest_client_gen.models import INDENT, ModelsStructureType, OBJECTS_DELIMITER
from . import indent, sort_fields
from ..dynamic_typing import ImportPathList, MetaData, ModelMeta, metadata_to_typing
from . import INDENT, ModelsStructureType, OBJECTS_DELIMITER, indent, sort_fields
from ..dynamic_typing import AbsoluteModelRef, ImportPathList, MetaData, ModelMeta, compile_imports, metadata_to_typing


def template(pattern: str, indent: str = INDENT) -> Template:
Expand Down Expand Up @@ -82,7 +81,7 @@ def field_data(self, name: str, meta: MetaData, optional: bool) -> Tuple[ImportP
"""
imports, typing = metadata_to_typing(meta)
data = {
"name": name,
"name": inflection.underscore(name),
"type": typing
}
return imports, data
Expand Down
175 changes: 175 additions & 0 deletions test/test_code_generation/test_attrs_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from typing import Dict, List

import pytest

from rest_client_gen.dynamic_typing import (DList, DOptional, FloatString, IntString, ModelMeta, compile_imports)
from rest_client_gen.models import sort_fields
from rest_client_gen.models.attr import AttrsModelCodeGenerator, METADATA_FIELD_NAME, sort_kwargs
from rest_client_gen.models.base import generate_code
from test.test_code_generation.test_models_code_generator import model_factory, trim


def test_attrib_kwargs_sort():
sorted_kwargs = sort_kwargs(dict(
y=2,
metadata='b',
converter='a',
default=None,
x=1,
))
expected = ['default', 'converter', 'y', 'x', 'metadata']
for k1, k2 in zip(sorted_kwargs.keys(), expected):
assert k1 == k2
try:
sort_kwargs({}, ['wrong_char'])
except ValueError as e:
assert e.args[0].endswith('wrong_char')
else:
assert 0, "XPass"



def field_meta(original_name):
return f"metadata={{'{METADATA_FIELD_NAME}': '{original_name}'}}"


# Data structure:
# pytest.param id -> {
# "model" -> (model_name, model_metadata),
# test_name -> expected, ...
# }
test_data = {
"base": {
"model": ("Test", {
"foo": int,
"bar": int,
"baz": float
}),
"fields_data": {
"foo": {
"name": "foo",
"type": "int",
"body": f"attr.ib({field_meta('foo')})"
},
"bar": {
"name": "bar",
"type": "int",
"body": f"attr.ib({field_meta('bar')})"
},
"baz": {
"name": "baz",
"type": "float",
"body": f"attr.ib({field_meta('baz')})"
}
},
"fields": {
"imports": "",
"fields": [
f"foo: int = attr.ib({field_meta('foo')})",
f"bar: int = attr.ib({field_meta('bar')})",
f"baz: float = attr.ib({field_meta('baz')})",
]
},
"generated": trim(f"""
import attr


@attr.s
class Test:
foo: int = attr.ib({field_meta('foo')})
bar: int = attr.ib({field_meta('bar')})
baz: float = attr.ib({field_meta('baz')})
""")
},
"complex": {
"model": ("Test", {
"foo": int,
"baz": DOptional(DList(DList(str))),
"bar": DOptional(IntString),
"qwerty": FloatString,
"asdfg": DOptional(int)
}),
"fields_data": {
"foo": {
"name": "foo",
"type": "int",
"body": f"attr.ib({field_meta('foo')})"
},
"baz": {
"name": "baz",
"type": "Optional[List[List[str]]]",
"body": f"attr.ib(factory=list, {field_meta('baz')})"
},
"bar": {
"name": "bar",
"type": "Optional[IntString]",
"body": f"attr.ib(default=None, converter=optional(IntString), {field_meta('bar')})"
},
"qwerty": {
"name": "qwerty",
"type": "FloatString",
"body": f"attr.ib(converter=FloatString, {field_meta('qwerty')})"
},
"asdfg": {
"name": "asdfg",
"type": "Optional[int]",
"body": f"attr.ib(default=None, {field_meta('asdfg')})"
}
},
"generated": trim(f"""
import attr
from attr.converter import optional
from rest_client_gen.dynamic_typing.string_serializable import FloatString, IntString
from typing import List, Optional


@attr.s
class Test:
foo: int = attr.ib({field_meta('foo')})
qwerty: FloatString = attr.ib(converter=FloatString, {field_meta('qwerty')})
baz: Optional[List[List[str]]] = attr.ib(factory=list, {field_meta('baz')})
bar: Optional[IntString] = attr.ib(default=None, converter=optional(IntString), {field_meta('bar')})
asdfg: Optional[int] = attr.ib(default=None, {field_meta('asdfg')})
""")
}
}

test_data_unzip = {
test: [
pytest.param(
model_factory(*data["model"]),
data[test],
id=id
)
for id, data in test_data.items()
if test in data
]
for test in ("fields_data", "fields", "generated")
}


@pytest.mark.parametrize("value,expected", test_data_unzip["fields_data"])
def test_fields_data_attr(value: ModelMeta, expected: Dict[str, dict]):
gen = AttrsModelCodeGenerator(value)
required, optional = sort_fields(value)
for is_optional, fields in enumerate((required, optional)):
for field in fields:
field_imports, data = gen.field_data(field, value.type[field], bool(is_optional))
assert data == expected[field]


@pytest.mark.parametrize("value,expected", test_data_unzip["fields"])
def test_fields_attr(value: ModelMeta, expected: dict):
expected_imports: str = expected["imports"]
expected_fields: List[str] = expected["fields"]
gen = AttrsModelCodeGenerator(value)
imports, fields = gen.fields
imports = compile_imports(imports)
assert imports == expected_imports
assert fields == expected_fields


@pytest.mark.parametrize("value,expected", test_data_unzip["generated"])
def test_generated_attr(value: ModelMeta, expected: str):
generated = generate_code(([{"model": value, "nested": []}], {}), AttrsModelCodeGenerator)
assert generated.rstrip() == expected, generated
Loading