diff --git a/HISTORY.md b/HISTORY.md index d577fc23..f020277b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,7 @@ ([#410](https://github.com/python-attrs/cattrs/issues/410) [#411](https://github.com/python-attrs/cattrs/pull/411)) - Introduce the `use_class_methods` strategy. Learn more [here](https://catt.rs/en/latest/strategies.html#using-class-specific-structure-and-unstructure-methods). ([#405](https://github.com/python-attrs/cattrs/pull/405)) +- Implement the `union passthrough` strategy, enabling much richer union handling for preconfigured converters. [Learn more here](https://catt.rs/en/stable/strategies.html#union-passthrough). - The `omit` parameter of {py:func}`cattrs.override` is now of type `bool | None` (from `bool`). `None` is the new default and means to apply default _cattrs_ handling to the attribute, which is to omit the attribute if it's marked as `init=False`, and keep it otherwise. - Fix {py:func}`format_exception() ` parameter working for recursive calls to {py:func}`transform_error `. @@ -40,6 +41,8 @@ ([#418](https://github.com/python-attrs/cattrs/issues/418)) - Add support for `date` to preconfigured converters. ([#420](https://github.com/python-attrs/cattrs/pull/420)) +- Add support for `datetime.date`s to the PyYAML preconfigured converter. 
+ ([#393](https://github.com/python-attrs/cattrs/issues/393)) ## 23.1.2 (2023-06-02) diff --git a/docs/_static/custom.css b/docs/_static/custom.css index cdfdc6c5..26ec87e4 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -101,6 +101,7 @@ span:target ~ h4:first-of-type, span:target ~ h5:first-of-type, span:target ~ h6:first-of-type { text-decoration: underline dashed; + text-decoration-thickness: 1px; } div.article-container > article { diff --git a/docs/strategies.md b/docs/strategies.md index 4c312cc1..1d888c64 100644 --- a/docs/strategies.md +++ b/docs/strategies.md @@ -323,3 +323,62 @@ Nested(m=MyClass(a=43)) ```{versionadded} 23.2.0 ``` + +## Union Passthrough + +_Found at {py:func}`cattrs.strategies.configure_union_passthrough`._ + +The _union passthrough_ strategy enables a {py:class}`Converter <cattrs.Converter>` to structure unions and subunions of given types. + +A very common use case for _cattrs_ is processing data created by other serialization libraries, such as _JSON_ or _msgpack_. +These libraries are able to directly produce values of unions inherent to the format. +For example, every JSON library can differentiate between numbers, booleans, strings and null values since these values are represented differently in the wire format. +This strategy enables _cattrs_ to offload the creation of these values to an underlying library and just validate the final value. +So, _cattrs_ preconfigured JSON converters can handle the following type: + +- `bool | int | float | str | None` + +Continuing the JSON example, this strategy also enables structuring subsets of unions of these values. +Accordingly, here are some examples of subset unions that are also supported: + +- `bool | int` +- `int | str` +- `int | float | str` + +The strategy also supports types including one or more [Literals](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types) of supported types. 
For example: + +- `Literal["admin", "user"] | int` +- `Literal[True] | str | int | float` + +The strategy also supports [NewTypes](https://mypy.readthedocs.io/en/stable/more_types.html#newtypes) of these types. For example: + +```python +>>> from typing import NewType + +>>> UserId = NewType("UserId", int) + +>>> converter.loads("12", UserId) +12 +``` + +Unions containing unsupported types can be handled if at least one union type is supported by the strategy; the supported union types will be checked before the rest (referred to as the _spillover_) is handed over to the converter again. + +For example, if `A` and `B` are arbitrary _attrs_ classes, the union `Literal[10] | A | B` cannot be handled directly by a JSON converter. +However, the strategy will check if the value being structured matches `Literal[10]` (because this type _is_ supported) and, if not, will pass it back to the converter to be structured as `A | B` (where a different strategy can handle it). + +The strategy is designed to run in _O(1)_ at structure time; it doesn't depend on the size of the union or the ordering of union members. + +This strategy has been preapplied to the following preconfigured converters: + +- {py:class}`BsonConverter <cattrs.preconf.bson.BsonConverter>` +- {py:class}`Cbor2Converter <cattrs.preconf.cbor2.Cbor2Converter>` +- {py:class}`JsonConverter <cattrs.preconf.json.JsonConverter>` +- {py:class}`MsgpackConverter <cattrs.preconf.msgpack.MsgpackConverter>` +- {py:class}`OrjsonConverter <cattrs.preconf.orjson.OrjsonConverter>` +- {py:class}`PyyamlConverter <cattrs.preconf.pyyaml.PyyamlConverter>` +- {py:class}`TomlkitConverter <cattrs.preconf.tomlkit.TomlkitConverter>` +- {py:class}`UjsonConverter <cattrs.preconf.ujson.UjsonConverter>` + +```{versionadded} 23.2.0 + +``` diff --git a/pdm.lock b/pdm.lock index b39558a7..f2c96dee 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,11 +2,11 @@ # It is not intended for manual editing. 
[metadata] -groups = ["default", "bson", "cbor2", "docs", "lint", "msgpack", "orjson", "pyyaml", "test", "tomlkit", "ujson"] +groups = ["default", "bson", "cbor2", "docs", "lint", "msgpack", "orjson", "pyyaml", "test", "tomlkit", "ujson", "bench"] cross_platform = true static_urls = false lock_version = "4.3" -content_hash = "sha256:0a4110571e06ea2153a3ae2359183ced5329906bafcdb9e79cbddff75da3fdf5" +content_hash = "sha256:3a391fc210d959b9ecc1ee3628620213d11c06105442911303b370a0243fb5a3" [[package]] name = "alabaster" @@ -836,6 +836,21 @@ files = [ {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, ] +[[package]] +name = "psutil" +version = "5.9.5" +requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +summary = "Cross-platform lib for process and system monitoring in Python." +files = [ + {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, + {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, + {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, + {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, + {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ 
-940,6 +955,19 @@ files = [ {file = "pymongo-4.4.0.tar.gz", hash = "sha256:a1b5d286fee4b9b5a0312faede02f2ce2f56ac695685af1d25f428abdac9a22c"}, ] +[[package]] +name = "pyperf" +version = "2.6.1" +requires_python = ">=3.7" +summary = "Python module to run and analyze benchmarks" +dependencies = [ + "psutil>=5.9.0", +] +files = [ + {file = "pyperf-2.6.1-py3-none-any.whl", hash = "sha256:9f81bf78335428ddf9845f1388dfb56181e744a69e93d8506697a56dc67b6d5f"}, + {file = "pyperf-2.6.1.tar.gz", hash = "sha256:171aea69b8efde61210e512166d8764e7765a9c7678b768052174b01f349f247"}, +] + [[package]] name = "pytest" version = "7.4.0" diff --git a/pyproject.toml b/pyproject.toml index feb5e697..fbae29c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,32 @@ known_first_party = ["cattr"] [tool.hatch.build.targets.wheel] packages = ["src/cattr", "src/cattrs"] + +[tool.pdm.dev-dependencies] +lint = [ + "isort>=5.11.5", + "black>=23.3.0", + "ruff>=0.0.277", +] +test = [ + "hypothesis>=6.79.4", + "pytest>=7.4.0", + "pytest-benchmark>=4.0.0", + "immutables>=0.19", + "typing-extensions>=4.7.1", + "coverage>=7.2.7", +] +docs = [ + "sphinx>=5.3.0", + "furo>=2023.3.27", + "sphinx-copybutton>=0.5.2", + "myst-parser>=1.0.0", + "pendulum>=2.1.2", +] +bench = [ + "pyperf>=2.6.1", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" @@ -62,28 +88,6 @@ bson = [ "pymongo>=4.4.0", ] -[tool.pdm.dev-dependencies] -lint = [ - "isort>=5.11.5", - "black>=23.3.0", - "ruff>=0.0.277", -] -test = [ - "hypothesis>=6.79.4", - "pytest>=7.4.0", - "pytest-benchmark>=4.0.0", - "immutables>=0.19", - "typing-extensions>=4.7.1", - "coverage>=7.2.7", -] -docs = [ - "sphinx>=5.3.0", - "furo>=2023.3.27", - "sphinx-copybutton>=0.5.2", - "myst-parser>=1.0.0", - "pendulum>=2.1.2", -] - [tool.pytest.ini_options] addopts = "-l --benchmark-sort=fullname --benchmark-warmup=true --benchmark-warmup-iterations=5 --benchmark-group-by=fullname" @@ -111,6 +115,7 @@ select = [ "B", # flake8-bugbear 
"C4", # flake8-comprehensions "T10", # flake8-debugger + "T20", # flake8-print "ISC", # flake8-implicit-str-concat "RET", # flake8-return "SIM", # flake8-simplify diff --git a/src/cattrs/preconf/bson.py b/src/cattrs/preconf/bson.py index 618a4907..c7a6a4e1 100644 --- a/src/cattrs/preconf/bson.py +++ b/src/cattrs/preconf/bson.py @@ -1,14 +1,15 @@ """Preconfigured converters for bson.""" from base64 import b85decode, b85encode -from datetime import datetime, date -from typing import Any, Type, TypeVar +from datetime import date, datetime +from typing import Any, Type, TypeVar, Union -from bson import DEFAULT_CODEC_OPTIONS, CodecOptions, ObjectId, decode, encode +from bson import DEFAULT_CODEC_OPTIONS, CodecOptions, Int64, ObjectId, decode, encode from cattrs._compat import AbstractSet, is_mapping from cattrs.gen import make_mapping_structure_fn from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough from . import validate_datetime T = TypeVar("T") @@ -83,6 +84,9 @@ def gen_structure_mapping(cl: Any): ) converter.register_structure_hook(ObjectId, lambda v, _: ObjectId(v)) + configure_union_passthrough( + Union[str, bool, int, float, None, bytes, datetime, ObjectId, Int64], converter + ) # datetime inherits from date, so identity unstructure hook used # here to prevent the date unstructure hook running. 
diff --git a/src/cattrs/preconf/cbor2.py b/src/cattrs/preconf/cbor2.py index 11756f04..444014b4 100644 --- a/src/cattrs/preconf/cbor2.py +++ b/src/cattrs/preconf/cbor2.py @@ -1,12 +1,13 @@ """Preconfigured converters for cbor2.""" -from datetime import datetime, timezone, date -from typing import Any, Type, TypeVar +from datetime import date, datetime, timezone +from typing import Any, Type, TypeVar, Union from cbor2 import dumps, loads from cattrs._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough T = TypeVar("T") @@ -32,6 +33,7 @@ def configure_converter(converter: BaseConverter): ) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) def make_converter(*args: Any, **kwargs: Any) -> Cbor2Converter: diff --git a/src/cattrs/preconf/json.py b/src/cattrs/preconf/json.py index 61abc365..e4d52a3c 100644 --- a/src/cattrs/preconf/json.py +++ b/src/cattrs/preconf/json.py @@ -1,12 +1,13 @@ """Preconfigured converters for the stdlib json.""" from base64 import b85decode, b85encode -from datetime import datetime, date +from datetime import date, datetime from json import dumps, loads from typing import Any, Type, TypeVar, Union from cattrs._compat import AbstractSet, Counter from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough T = TypeVar("T") @@ -36,6 +37,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) def make_converter(*args: Any, **kwargs: Any) -> 
JsonConverter: diff --git a/src/cattrs/preconf/msgpack.py b/src/cattrs/preconf/msgpack.py index eb13b6e6..2e7470b6 100644 --- a/src/cattrs/preconf/msgpack.py +++ b/src/cattrs/preconf/msgpack.py @@ -1,12 +1,13 @@ """Preconfigured converters for msgpack.""" -from datetime import datetime, timezone, date, time -from typing import Any, Type, TypeVar +from datetime import date, datetime, time, timezone +from typing import Any, Type, TypeVar, Union from msgpack import dumps, loads from cattrs._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough T = TypeVar("T") @@ -36,6 +37,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook( date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date() ) + configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) def make_converter(*args: Any, **kwargs: Any) -> MsgpackConverter: diff --git a/src/cattrs/preconf/orjson.py b/src/cattrs/preconf/orjson.py index 0be83049..0b4f32de 100644 --- a/src/cattrs/preconf/orjson.py +++ b/src/cattrs/preconf/orjson.py @@ -9,6 +9,7 @@ from cattrs._compat import AbstractSet, is_mapping from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough T = TypeVar("T") @@ -66,6 +67,7 @@ def key_handler(v): converter._unstructure_func.register_func_list( [(is_mapping, gen_unstructure_mapping, True)] ) + configure_union_passthrough(Union[str, bool, int, float, None], converter) def make_converter(*args: Any, **kwargs: Any) -> OrjsonConverter: diff --git a/src/cattrs/preconf/pyyaml.py b/src/cattrs/preconf/pyyaml.py index 5de6e9cf..091c1d37 100644 --- a/src/cattrs/preconf/pyyaml.py +++ b/src/cattrs/preconf/pyyaml.py @@ -1,17 +1,24 @@ """Preconfigured converters for pyyaml.""" -from datetime import datetime, date -from typing import Any, Type, TypeVar +from datetime import date, datetime +from typing import Any, Type, TypeVar, Union 
from yaml import safe_dump, safe_load from cattrs._compat import FrozenSetSubscriptable from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough from . import validate_datetime T = TypeVar("T") +def validate_date(v, _): + if not isinstance(v, date): + raise ValueError(f"Expected date, got {v}") + return v + + class PyyamlConverter(Converter): def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> str: return safe_dump(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs) @@ -26,6 +33,7 @@ def configure_converter(converter: BaseConverter): * frozensets are serialized as lists * string enums are converted into strings explicitly + * datetimes and dates are validated """ converter.register_unstructure_hook( str, lambda v: v if v.__class__ is str else v.value @@ -35,8 +43,10 @@ def configure_converter(converter: BaseConverter): # here to prevent the date unstructure hook running. converter.register_unstructure_hook(datetime, lambda v: v) converter.register_structure_hook(datetime, validate_datetime) - converter.register_unstructure_hook(date, lambda v: v.isoformat()) - converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_structure_hook(date, validate_date) + configure_union_passthrough( + Union[str, bool, int, float, None, bytes, datetime, date], converter + ) def make_converter(*args: Any, **kwargs: Any) -> PyyamlConverter: diff --git a/src/cattrs/preconf/tomlkit.py b/src/cattrs/preconf/tomlkit.py index 5ee8d1c0..8cdfeac7 100644 --- a/src/cattrs/preconf/tomlkit.py +++ b/src/cattrs/preconf/tomlkit.py @@ -1,15 +1,17 @@ """Preconfigured converters for tomlkit.""" from base64 import b85decode, b85encode -from datetime import datetime, date +from datetime import date, datetime from enum import Enum from operator import attrgetter -from typing import Any, Type, TypeVar +from typing import Any, Type, TypeVar, Union from tomlkit import dumps, loads +from 
tomlkit.items import Float, Integer, String from cattrs._compat import AbstractSet, is_mapping from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough from . import validate_datetime T = TypeVar("T") @@ -66,6 +68,9 @@ def key_handler(k: bytes): converter.register_structure_hook(datetime, validate_datetime) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + configure_union_passthrough( + Union[str, String, bool, int, Integer, float, Float], converter + ) def make_converter(*args: Any, **kwargs: Any) -> TomlkitConverter: diff --git a/src/cattrs/preconf/ujson.py b/src/cattrs/preconf/ujson.py index d48abb43..b6de8e85 100644 --- a/src/cattrs/preconf/ujson.py +++ b/src/cattrs/preconf/ujson.py @@ -1,13 +1,14 @@ """Preconfigured converters for ujson.""" from base64 import b85decode, b85encode -from datetime import datetime, date -from typing import Any, AnyStr, Type, TypeVar +from datetime import date, datetime +from typing import Any, AnyStr, Type, TypeVar, Union from ujson import dumps, loads from cattrs._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..strategies import configure_union_passthrough T = TypeVar("T") @@ -37,6 +38,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + configure_union_passthrough(Union[str, bool, int, float, None], converter) def make_converter(*args: Any, **kwargs: Any) -> UjsonConverter: diff --git a/src/cattrs/strategies/__init__.py b/src/cattrs/strategies/__init__.py index 563caa06..c2fe4fb7 100644 --- a/src/cattrs/strategies/__init__.py +++ b/src/cattrs/strategies/__init__.py @@ -1,6 +1,11 @@ """High level strategies for converters.""" 
from ._class_methods import use_class_methods from ._subclasses import include_subclasses -from ._unions import configure_tagged_union +from ._unions import configure_tagged_union, configure_union_passthrough -__all__ = ["configure_tagged_union", "include_subclasses", "use_class_methods"] +__all__ = [ + "configure_tagged_union", + "configure_union_passthrough", + "include_subclasses", + "use_class_methods", +] diff --git a/src/cattrs/strategies/_unions.py b/src/cattrs/strategies/_unions.py index b40fbbe0..bc681e45 100644 --- a/src/cattrs/strategies/_unions.py +++ b/src/cattrs/strategies/_unions.py @@ -1,11 +1,16 @@ from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union from attrs import NOTHING -from cattrs import Converter +from cattrs import BaseConverter, Converter +from cattrs._compat import get_newtype_base, is_literal, is_subclass, is_union_type -__all__ = ["default_tag_generator", "configure_tagged_union"] +__all__ = [ + "default_tag_generator", + "configure_tagged_union", + "configure_union_passthrough", +] def default_tag_generator(typ: Type) -> str: @@ -101,3 +106,128 @@ def structure_tagged_union( converter.register_unstructure_hook(union, unstructure_tagged_union) converter.register_structure_hook(union, structure_tagged_union) + + +def configure_union_passthrough(union: Any, converter: BaseConverter) -> None: + """ + Configure the converter to support validating and passing through unions of the + provided types and their subsets. + + For example, all mature JSON libraries natively support producing unions of ints, + floats, Nones, and strings. Using this strategy, a converter can be configured + to efficiently validate and pass through unions containing these types. + + The most important point is that another library (in this example the JSON + library) handles producing the union, and the converter is configured to just + validate it. 
+ + Literals of provided types are also supported, and are checked by value. + + NewTypes of provided types are also supported. + + The strategy is designed to be O(1) in execution time, and independent of the + ordering of types in the union. + + If the union contains a class and one or more of its subclasses, the subclasses + will also be included when validating the superclass. + + .. versionadded:: 23.2.0 + """ + args = set(union.__args__) + + def make_structure_native_union(exact_type: Any) -> Callable: + # `exact_type` is likely to be a subset of the entire configured union (`args`). + literal_values = { + v for t in exact_type.__args__ if is_literal(t) for v in t.__args__ + } + + # We have no idea what the actual type of `val` will be, so we can't + # use it blindly with an `in` check since it might not be hashable. + # So we do an additional check when handling literals. + # Note: do not use `literal_values` here, since {0, False} gets reduced to {0} + literal_classes = { + v.__class__ + for t in exact_type.__args__ + if is_literal(t) + for v in t.__args__ + } + + non_literal_classes = { + get_newtype_base(t) or t + for t in exact_type.__args__ + if not is_literal(t) and ((get_newtype_base(t) or t) in args) + } + + # We augment the set of allowed classes with any configured subclasses of + # the exact subclasses. + non_literal_classes |= { + a for a in args if any(is_subclass(a, c) for c in non_literal_classes) + } + + # We check for spillover - union types not handled by the strategy. + # If spillover exists and we fail to validate our types, we call + # further into the converter with the rest. 
+ spillover = { + a + for a in exact_type.__args__ + if (get_newtype_base(a) or a) not in non_literal_classes + and not is_literal(a) + } + + if spillover: + spillover_type = ( + Union[tuple(spillover)] if len(spillover) > 1 else next(iter(spillover)) + ) + + def structure_native_union( + val: Any, + _: Any, + classes=non_literal_classes, + vals=literal_values, + converter=converter, + spillover=spillover_type, + ) -> exact_type: + if val.__class__ in literal_classes and val in vals: + return val + if val.__class__ in classes: + return val + return converter.structure(val, spillover) + + else: + + def structure_native_union( + val: Any, _: Any, classes=non_literal_classes, vals=literal_values + ) -> exact_type: + if val.__class__ in literal_classes and val in vals: + return val + if val.__class__ in classes: + return val + raise TypeError(f"{val} ({val.__class__}) not part of {_}") + + return structure_native_union + + def contains_native_union(exact_type: Any) -> bool: + """Can we handle this type?""" + if is_union_type(exact_type): + type_args = set(exact_type.__args__) + # We special case optionals, since they are very common + # and are handled a little more efficiently by default. + if len(type_args) == 2 and type(None) in type_args: + return False + + literal_classes = { + lit_arg.__class__ + for t in type_args + if is_literal(t) + for lit_arg in t.__args__ + } + non_literal_types = { + get_newtype_base(t) or t for t in type_args if not is_literal(t) + } + + return (literal_classes | non_literal_types) & args + return False + + converter.register_structure_hook_factory( + contains_native_union, make_structure_native_union + ) diff --git a/tests/strategies/test_native_unions.py b/tests/strategies/test_native_unions.py new file mode 100644 index 00000000..6cbb5ac4 --- /dev/null +++ b/tests/strategies/test_native_unions.py @@ -0,0 +1,113 @@ +"""Tests for the native union passthrough strategy. 
+ +Note that a significant amount of test coverage for this is in the +preconf tests. +""" +from typing import List, Optional, Union + +import pytest +from attrs import define + +from cattrs import BaseConverter +from cattrs.strategies import configure_union_passthrough + +from .._compat import is_py37 + + +def test_only_primitives(converter: BaseConverter) -> None: + """A native union with only primitives works.""" + union = Union[int, str, None] + configure_union_passthrough(union, converter) + + assert converter.unstructure(1, union) == 1 + assert converter.structure(1, union) == 1 + assert converter.unstructure("1", union) == "1" + assert converter.structure("1", union) == "1" + assert converter.unstructure(None, union) is None + assert converter.structure(None, union) is None + + with pytest.raises(TypeError): + converter.structure((), union) + + +@pytest.mark.skipif(is_py37, reason="Not supported on 3.7") +def test_literals(converter: BaseConverter) -> None: + """A union with primitives and literals works.""" + from typing import Literal + + union = Union[int, str, None] + exact_type = Union[int, Literal["test"], None] + configure_union_passthrough(union, converter) + + assert converter.unstructure(1, exact_type) == 1 + assert converter.structure(1, exact_type) == 1 + assert converter.unstructure("test", exact_type) == "test" + assert converter.structure("test", exact_type) == "test" + assert converter.unstructure(None, exact_type) is None + assert converter.structure(None, exact_type) is None + + with pytest.raises(TypeError): + converter.structure((), exact_type) + with pytest.raises(TypeError): + converter.structure("t", exact_type) + + +def test_skip_optionals() -> None: + """ + The strategy skips Optionals, since those are more efficiently + handled by default. 
+ """ + c = BaseConverter() + + configure_union_passthrough(Union[int, str, None], c) + + h = c._structure_func.dispatch(Optional[int]) + assert h.__name__ != "structure_native_union" + + +def test_spillover(converter: BaseConverter) -> None: + """Types not covered by the native union are correctly handled.""" + union = Union[int, str, None] + exact_type = Union[int, List[str], None] + + configure_union_passthrough(union, converter) + + assert converter.unstructure(1, exact_type) == 1 + assert converter.structure(1, exact_type) == 1 + + assert converter.unstructure(["a", "b"], exact_type) == ["a", "b"] + assert converter.structure(["a", "b"], exact_type) == ["a", "b"] + + with pytest.raises(TypeError): + converter.structure((), union) + + +def test_multiple_spillover(converter: BaseConverter) -> None: + """Types not covered by the native union are correctly handled.""" + union = Union[int, str, None] + + @define + class A: + a: int + + @define + class B: + b: int + + # A | B will be handled by the default disambiguator. 
+ exact_type = Union[int, List[str], A, B, None] + + configure_union_passthrough(union, converter) + + assert converter.unstructure(1, exact_type) == 1 + assert converter.structure(1, exact_type) == 1 + + assert converter.unstructure(["a", "b"], exact_type) == ["a", "b"] + assert converter.structure(["a", "b"], List[str]) == ["a", "b"] + assert converter.unstructure(A(1), exact_type) == {"a": 1} + assert converter.structure({"a": 1}, A) == A(1) + assert converter.unstructure(B(1), exact_type) == {"b": 1} + assert converter.structure({"b": 1}, B) == B(1) + + with pytest.raises(TypeError): + converter.structure((), union) diff --git a/tests/test_preconf.py b/tests/test_preconf.py index acb13e5e..f547e8de 100644 --- a/tests/test_preconf.py +++ b/tests/test_preconf.py @@ -1,27 +1,32 @@ -from datetime import datetime, timezone, date +import sys +from datetime import date, datetime, timezone from enum import Enum, IntEnum, unique from json import dumps as json_dumps from json import loads as json_loads from platform import python_implementation -from typing import Dict, List +from typing import Any, Dict, List, NewType, Tuple, Union import pytest -from attr import define +from attrs import define from bson import CodecOptions, ObjectId -from hypothesis import given +from hypothesis import given, settings from hypothesis.strategies import ( + DrawFn, binary, booleans, + builds, characters, composite, - datetimes, dates, + datetimes, dictionaries, floats, frozensets, integers, just, lists, + one_of, + sampled_from, sets, text, ) @@ -48,6 +53,16 @@ from cattrs.preconf.ujson import make_converter as ujson_make_converter +@define +class A: + a: int + + +@define +class B: + b: str + + @define class Everything: @unique @@ -80,11 +95,14 @@ class AStringEnum(str, Enum): a_date: date a_string_enum_dict: Dict[AStringEnum, int] a_bytes_dict: Dict[bytes, bytes] + native_union: Union[int, float, str] + native_union_with_spillover: Union[int, str, Set[str]] + 
native_union_with_union_spillover: Union[int, str, A, B] @composite def everythings( - draw, + draw: DrawFn, min_int=None, max_int=None, allow_inf=True, @@ -119,48 +137,134 @@ def everythings( d.year, d.month, d.day, d.hour, d.minute, d.second, tzinfo=d.tzinfo ) ) + fs = floats(allow_nan=False, allow_infinity=allow_inf) + ints = integers(min_value=min_int, max_value=max_int) + return Everything( draw(strings), draw(binary()), - draw(integers(min_value=min_int, max_value=max_int)), - draw(floats(allow_nan=False, allow_infinity=allow_inf)), - draw(dictionaries(key_text, integers(min_value=min_int, max_value=max_int))), - draw(lists(integers(min_value=min_int, max_value=max_int))), - tuple(draw(lists(integers(min_value=min_int, max_value=max_int)))), - ( - draw(strings), - draw(integers(min_value=min_int, max_value=max_int)), - draw(floats(allow_nan=False, allow_infinity=allow_inf)), - ), - Counter( - draw(dictionaries(key_text, integers(min_value=min_int, max_value=max_int))) - ), - draw( - dictionaries( - integers(min_value=min_int, max_value=max_int), - floats(allow_nan=False, allow_infinity=allow_inf), - ) - ), - draw(dictionaries(floats(allow_nan=False, allow_infinity=allow_inf), strings)), - draw(lists(floats(allow_nan=False, allow_infinity=allow_inf))), + draw(ints), + draw(fs), + draw(dictionaries(key_text, ints)), + draw(lists(ints)), + tuple(draw(lists(ints))), + (draw(strings), draw(ints), draw(fs)), + Counter(draw(dictionaries(key_text, ints))), + draw(dictionaries(ints, fs)), + draw(dictionaries(fs, strings)), + draw(lists(fs)), draw(lists(strings)), - draw(sets(floats(allow_nan=False, allow_infinity=allow_inf))), - draw(sets(integers(min_value=min_int, max_value=max_int))), + draw(sets(fs)), + draw(sets(ints)), draw(frozensets(strings)), Everything.AnIntEnum.A, Everything.AStringEnum.A, draw(dts), draw(dates(min_value=date(1970, 1, 1), max_value=date(2038, 1, 1))), - draw( - dictionaries( - just(Everything.AStringEnum.A), - integers(min_value=min_int, 
max_value=max_int), - ) - ), + draw(dictionaries(just(Everything.AStringEnum.A), ints)), draw(dictionaries(binary(min_size=min_key_length), binary())), + draw(one_of(ints, fs, strings)), + draw(one_of(ints, strings, sets(strings))), + draw(one_of(ints, strings, ints.map(A), strings.map(B))), ) +NewStr = NewType("NewStr", str) +NewInt = NewType("NewInt", int) +NewBool = NewType("NewBool", bool) + + +@composite +def native_unions( + draw: DrawFn, + include_strings=True, + include_bools=True, + include_ints=True, + include_floats=True, + include_nones=True, + include_bytes=True, + include_datetimes=True, + include_objectids=False, + include_literals=True, +) -> Tuple[Any, Any]: + types = [] + strats = {} + if include_strings: + types.append(str) + strats[str] = text() + if include_bools: + types.append(bool) + strats[bool] = booleans() + if include_ints: + types.append(int) + strats[int] = integers() + if include_floats: + types.append(float) + strats[float] = floats(allow_nan=False) + if include_nones: + types.append(None) + strats[None] = just(None) + if include_bytes: + types.append(bytes) + strats[bytes] = binary() + if include_datetimes: + types.append(datetime) + strats[datetime] = datetimes( + min_value=datetime(1970, 1, 1), max_value=datetime(2038, 1, 1) + ) + if include_objectids: + types.append(ObjectId) + strats[ObjectId] = builds(ObjectId) + + chosen_types = draw(sets(sampled_from(types), min_size=2)) + + if include_literals: + from typing import Literal + + # We can replace some of the types with 1+ literal types. 
+ if str in chosen_types: + strat = draw(sampled_from(["leave", "literal", "newtype"])) + if strat == "literal": + chosen_types.remove(str) + vals = draw(sets(text(), min_size=1, max_size=2)) + for lit in vals: + t = Literal[lit] + chosen_types.add(t) + strats[t] = just(lit) + elif strat == "newtype": + chosen_types.remove(str) + chosen_types.add(NewStr) + strats[NewStr] = strats.pop(str) + if bool in chosen_types: + strat = draw(sampled_from(["leave", "literal", "newtype"])) + if strat == "literal": + chosen_types.remove(bool) + val = draw(booleans()) + t = Literal[val] + chosen_types.add(t) + strats[t] = just(val) + elif strat == "newtype": + chosen_types.remove(bool) + chosen_types.add(NewBool) + strats[NewBool] = strats.pop(bool) + if int in chosen_types: + strat = draw(sampled_from(["leave", "literal", "newtype"])) + if strat == "literal": + chosen_types.remove(int) + vals = draw(sets(integers(), min_size=1, max_size=2)) + for val in vals: + t = Literal[val] + chosen_types.add(t) + strats[t] = just(val) + elif strat == "newtype": + # NewTypes instead. 
+ chosen_types.remove(int) + chosen_types.add(NewInt) + strats[NewInt] = strats.pop(int) + + return Union[tuple(chosen_types)], draw(one_of(*[strats[t] for t in chosen_types])) + + @given(everythings()) def test_stdlib_json(everything: Everything): converter = json_make_converter() @@ -187,6 +291,46 @@ def test_stdlib_json_converter_unstruct_collection_overrides(everything: Everyth assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_bytes=False, + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +@settings(max_examples=1000) +def test_stdlib_json_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + converter = json_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + +@given( + union_and_val=native_unions( + include_strings=False, + include_bytes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_stdlib_json_unions_with_spillover( + union_and_val: tuple, detailed_validation: bool +): + """Native union passthrough works and can handle spillover. + + The stdlib json converter cannot handle datetimes natively. 
+ """ + converter = json_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(converter.unstructure(val), type) == val + + @given( everythings( min_int=-9223372036854775808, max_int=9223372036854775807, allow_inf=False @@ -227,6 +371,23 @@ def test_ujson_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_bytes=False, + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_ujson_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + converter = ujson_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + @pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") @given( everythings( @@ -278,6 +439,26 @@ def test_orjson_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") +@given( + union_and_val=native_unions( + include_bytes=False, + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_orjson_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + from cattrs.preconf.orjson import make_converter as orjson_make_converter + + converter = orjson_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + @given(everythings(min_int=-9223372036854775808, max_int=18446744073709551615)) def 
test_msgpack(everything: Everything): from msgpack import dumps as msgpack_dumps @@ -309,6 +490,22 @@ def test_msgpack_converter_unstruct_collection_overrides(everything: Everything) assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_msgpack_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + converter = msgpack_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + @given( everythings( min_int=-9223372036854775808, @@ -368,6 +565,22 @@ def test_bson_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_objectids=True, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_bson_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + converter = bson_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + @given(everythings()) def test_pyyaml(everything: Everything): from yaml import safe_dump, safe_load @@ -394,6 +607,39 @@ def test_pyyaml_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_pyyaml_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + 
converter = pyyaml_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + +@given(detailed_validation=...) +def test_pyyaml_dates(detailed_validation: bool): + """Pyyaml dates work.""" + converter = pyyaml_make_converter(detailed_validation=detailed_validation) + + @define + class A: + datetime: datetime + date: date + + data = """ + datetime: 1970-01-01T00:00:00Z + date: 1970-01-01""" + assert converter.loads(data, A) == A( + datetime(1970, 1, 1, tzinfo=timezone.utc), date(1970, 1, 1) + ) + + @given( everythings( min_key_length=1, @@ -446,6 +692,24 @@ def test_tomlkit_converter_unstruct_collection_overrides(everything: Everything) assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) +@given( + union_and_val=native_unions( + include_nones=False, + include_bytes=False, + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_tomlkit_unions(union_and_val: tuple, detailed_validation: bool): + """Native union passthrough works.""" + converter = tomlkit_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val + + def test_bson_objectid(): """BSON ObjectIds are supported by default.""" converter = bson_make_converter() @@ -480,3 +744,19 @@ def test_cbor2_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_set"] == sorted(raw["a_set"]) assert raw["a_mutable_set"] == sorted(raw["a_mutable_set"]) assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) + + +@given( + union_and_val=native_unions( + include_datetimes=False, + include_bools=sys.version_info[:2] != (3, 8), # Literal issues on 3.8 + include_literals=sys.version_info >= (3, 8), + ), + detailed_validation=..., +) +def test_cbor2_unions(union_and_val: tuple, detailed_validation: bool): + """Native 
union passthrough works.""" + converter = cbor2_make_converter(detailed_validation=detailed_validation) + type, val = union_and_val + + assert converter.structure(val, type) == val