From 735446d062d8648959a55acdaeea7bd5ded3f435 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Fri, 8 Nov 2024 21:05:20 +0100 Subject: [PATCH 1/3] Fix unstructuring literals with enums --- HISTORY.md | 1 + src/cattrs/_compat.py | 3 ++- src/cattrs/converters.py | 6 ++---- src/cattrs/literals.py | 11 +++++++++++ tests/test_literals.py | 19 +++++++++++++++++++ 5 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 src/cattrs/literals.py create mode 100644 tests/test_literals.py diff --git a/HISTORY.md b/HISTORY.md index 64f4425d..d81df8ab 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -22,6 +22,7 @@ Our backwards-compatibility policy can be found [here](https://github.com/python - Some `defaultdicts` are now [supported by default](https://catt.rs/en/latest/defaulthooks.html#defaultdicts), and {func}`cattrs.cols.is_defaultdict`{func} and `cattrs.cols.defaultdict_structure_factory` are exposed through {mod}`cattrs.cols`. ([#519](https://github.com/python-attrs/cattrs/issues/519) [#588](https://github.com/python-attrs/cattrs/pull/588)) +- Literals containing enums are now unstructured properly. - Replace `cattrs.gen.MappingStructureFn` with `cattrs.SimpleStructureHook[In, T]`. - Python 3.13 is now supported. ([#543](https://github.com/python-attrs/cattrs/pull/543) [#547](https://github.com/python-attrs/cattrs/issues/547)) diff --git a/src/cattrs/_compat.py b/src/cattrs/_compat.py index bc20b2dc..b691a7e7 100644 --- a/src/cattrs/_compat.py +++ b/src/cattrs/_compat.py @@ -236,7 +236,8 @@ def get_final_base(type) -> Optional[type]: # Not present on 3.9.0, so we try carefully. 
from typing import _LiteralGenericAlias - def is_literal(type) -> bool: + def is_literal(type: Any) -> bool: + """Is this a literal?""" return type in LITERALS or ( isinstance( type, (_GenericAlias, _LiteralGenericAlias, _SpecialGenericAlias) diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index 59764eb6..3644786c 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -91,6 +91,7 @@ ) from .gen.typeddicts import make_dict_structure_fn as make_typeddict_dict_struct_fn from .gen.typeddicts import make_dict_unstructure_fn as make_typeddict_dict_unstruct_fn +from .literals import is_literal_containing_enums from .types import SimpleStructureHook __all__ = ["UnstructureStrategy", "BaseConverter", "Converter", "GenConverter"] @@ -146,10 +147,6 @@ class UnstructureStrategy(Enum): AS_TUPLE = "astuple" -def is_literal_containing_enums(typ: type) -> bool: - return is_literal(typ) and any(isinstance(val, Enum) for val in typ.__args__) - - def _is_extended_factory(factory: Callable) -> bool: """Does this factory also accept a converter arg?""" # We use the original `inspect.signature` to not evaluate string @@ -238,6 +235,7 @@ def __init__( lambda t: self.get_unstructure_hook(get_type_alias_base(t)), True, ), + (is_literal_containing_enums, self.unstructure), (is_mapping, self._unstructure_mapping), (is_sequence, self._unstructure_seq), (is_mutable_set, self._unstructure_seq), diff --git a/src/cattrs/literals.py b/src/cattrs/literals.py new file mode 100644 index 00000000..badeddaf --- /dev/null +++ b/src/cattrs/literals.py @@ -0,0 +1,11 @@ +from enum import Enum +from typing import Any + +from ._compat import is_literal + +__all__ = ["is_literal", "is_literal_containing_enums"] + + +def is_literal_containing_enums(type: Any) -> bool: + """Is this a literal containing at least one Enum?""" + return is_literal(type) and any(isinstance(val, Enum) for val in type.__args__) diff --git a/tests/test_literals.py b/tests/test_literals.py new file mode 
100644 index 00000000..37317177 --- /dev/null +++ b/tests/test_literals.py @@ -0,0 +1,19 @@ +from enum import Enum +from typing import Literal + +from cattrs import BaseConverter +from cattrs.fns import identity + + +class TestEnum(Enum): + TEST = "test" + + +def test_unstructure_literal(converter: BaseConverter): + """Literals without enums are passed through by default.""" + assert converter.get_unstructure_hook(1, Literal[1]) == identity + + +def test_unstructure_literal_with_enum(converter: BaseConverter): + """Literals with enums are properly unstructured.""" + assert converter.unstructure(TestEnum.TEST, Literal[TestEnum.TEST]) == "test" From 877dc4ca0ced2f1326b5b2d90231e8bf1f03ffe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Fri, 8 Nov 2024 22:37:03 +0100 Subject: [PATCH 2/3] preconf: faster enum handling --- HISTORY.md | 4 + docs/preconf.md | 13 +- src/cattrs/preconf/__init__.py | 11 ++ src/cattrs/preconf/bson.py | 8 +- src/cattrs/preconf/cbor2.py | 5 +- src/cattrs/preconf/json.py | 8 +- src/cattrs/preconf/msgpack.py | 8 +- src/cattrs/preconf/msgspec.py | 12 +- src/cattrs/preconf/orjson.py | 8 +- src/cattrs/preconf/ujson.py | 11 +- tests/test_preconf.py | 217 +++++++++++++++++++++++++++++++-- 11 files changed, 279 insertions(+), 26 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d81df8ab..669f53d3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -22,7 +22,11 @@ Our backwards-compatibility policy can be found [here](https://github.com/python - Some `defaultdicts` are now [supported by default](https://catt.rs/en/latest/defaulthooks.html#defaultdicts), and {func}`cattrs.cols.is_defaultdict`{func} and `cattrs.cols.defaultdict_structure_factory` are exposed through {mod}`cattrs.cols`. 
([#519](https://github.com/python-attrs/cattrs/issues/519) [#588](https://github.com/python-attrs/cattrs/pull/588)) +- Many preconf converters (_bson_, stdlib JSON, _cbor2_, _msgpack_, _msgspec_, _orjson_, _ujson_) skip unstructuring `int` and `str` enums, + leaving them to the underlying libraries to handle with greater efficiency. + ([#598](https://github.com/python-attrs/cattrs/pull/598)) - Literals containing enums are now unstructured properly. + ([#598](https://github.com/python-attrs/cattrs/pull/598)) - Replace `cattrs.gen.MappingStructureFn` with `cattrs.SimpleStructureHook[In, T]`. - Python 3.13 is now supported. ([#543](https://github.com/python-attrs/cattrs/pull/543) [#547](https://github.com/python-attrs/cattrs/issues/547)) diff --git a/docs/preconf.md b/docs/preconf.md index 4a3038a9..c76b22e2 100644 --- a/docs/preconf.md +++ b/docs/preconf.md @@ -2,17 +2,26 @@ The {mod}`cattrs.preconf` package contains factories for preconfigured converters, specifically adjusted for particular serialization libraries. -For example, to get a converter configured for BSON: +For example, to get a converter configured for _orjson_: ```{doctest} ->>> from cattrs.preconf.bson import make_converter +>>> from cattrs.preconf.orjson import make_converter >>> converter = make_converter() # Takes the same parameters as the `cattrs.Converter` ``` Converters obtained this way can be customized further, just like any other converter. +For compatibility and performance reasons, these converters are usually configured to unstructure differently than ordinary `Converters`. +A couple of examples: +* the {class}`_orjson_ converter <cattrs.preconf.orjson.OrjsonConverter>` is configured to pass `datetime` instances unstructured since _orjson_ can handle them faster. +* the {class}`_msgspec_ JSON converter <cattrs.preconf.msgspec.MsgspecJsonConverter>` is configured to pass through some dataclasses and _attrs_ classes, +if the output is identical to what normal unstructuring would have produced, since _msgspec_ can handle them faster. 
+ +The intended usage is to pass the unstructured output directly to the underlying library, +or use `converter.dumps` which will do it for you. + These converters support all [default hooks](defaulthooks.md) and the following additional classes and type annotations, both for structuring and unstructuring: diff --git a/src/cattrs/preconf/__init__.py b/src/cattrs/preconf/__init__.py index 876576d1..ec1dfc31 100644 --- a/src/cattrs/preconf/__init__.py +++ b/src/cattrs/preconf/__init__.py @@ -1,7 +1,10 @@ import sys from datetime import datetime +from enum import Enum from typing import Any, Callable, TypeVar +from .._compat import is_subclass + if sys.version_info[:2] < (3, 10): from typing_extensions import ParamSpec else: @@ -25,3 +28,11 @@ def impl(x: Callable[..., T]) -> Callable[P, T]: return x return impl + + +def is_primitive_enum(type: Any, include_bare_enums: bool = False) -> bool: + """Is this a string or int enum that can be passed through?""" + return is_subclass(type, Enum) and ( + is_subclass(type, (str, int)) + or (include_bare_enums and type.mro()[1:] == Enum.mro()) + ) diff --git a/src/cattrs/preconf/bson.py b/src/cattrs/preconf/bson.py index 0d8f5c65..55307cab 100644 --- a/src/cattrs/preconf/bson.py +++ b/src/cattrs/preconf/bson.py @@ -11,8 +11,9 @@ from ..converters import BaseConverter, Converter from ..dispatch import StructureHook +from ..fns import identity from ..strategies import configure_union_passthrough -from . import validate_datetime, wrap +from . import is_primitive_enum, validate_datetime, wrap T = TypeVar("T") @@ -52,6 +53,10 @@ def configure_converter(converter: BaseConverter): * byte mapping keys are base85-encoded into strings when unstructuring, and reverse * non-string, non-byte mapping keys are coerced into strings when unstructuring * a deserialization hook is registered for bson.ObjectId by default + * string and int enums are passed through when unstructuring + + .. 
versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. """ def gen_unstructure_mapping(cl: Any, unstructure_to=None): @@ -92,6 +97,7 @@ def gen_structure_mapping(cl: Any) -> StructureHook: converter.register_structure_hook(datetime, validate_datetime) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_func(is_primitive_enum, identity) @wrap(BsonConverter) diff --git a/src/cattrs/preconf/cbor2.py b/src/cattrs/preconf/cbor2.py index 63600c6a..a963cb70 100644 --- a/src/cattrs/preconf/cbor2.py +++ b/src/cattrs/preconf/cbor2.py @@ -8,8 +8,9 @@ from cattrs._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..fns import identity from ..strategies import configure_union_passthrough -from . import wrap +from . import is_primitive_enum, wrap T = TypeVar("T") @@ -28,6 +29,7 @@ def configure_converter(converter: BaseConverter): * datetimes are serialized as timestamp floats * sets are serialized as lists + * string and int enums are passed through when unstructuring """ converter.register_unstructure_hook(datetime, lambda v: v.timestamp()) converter.register_structure_hook( @@ -35,6 +37,7 @@ def configure_converter(converter: BaseConverter): ) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_func(is_primitive_enum, identity) configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) diff --git a/src/cattrs/preconf/json.py b/src/cattrs/preconf/json.py index 85e0cbc9..cdfa5942 100644 --- a/src/cattrs/preconf/json.py +++ b/src/cattrs/preconf/json.py @@ -7,8 +7,9 @@ from .._compat import AbstractSet, Counter from ..converters import BaseConverter, Converter +from ..fns import identity from ..strategies import configure_union_passthrough 
-from . import wrap +from . import is_primitive_enum, wrap T = TypeVar("T") @@ -29,8 +30,12 @@ def configure_converter(converter: BaseConverter): * datetimes are serialized as ISO 8601 * counters are serialized as dicts * sets are serialized as lists + * string and int enums are passed through when unstructuring * union passthrough is configured for unions of strings, bools, ints, floats and None + + .. versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. """ converter.register_unstructure_hook( bytes, lambda v: (b85encode(v) if v else b"").decode("utf8") @@ -40,6 +45,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_func(is_primitive_enum, identity) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/msgpack.py b/src/cattrs/preconf/msgpack.py index 530c3b54..2a27cace 100644 --- a/src/cattrs/preconf/msgpack.py +++ b/src/cattrs/preconf/msgpack.py @@ -8,8 +8,9 @@ from cattrs._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..fns import identity from ..strategies import configure_union_passthrough -from . import wrap +from . import is_primitive_enum, wrap T = TypeVar("T") @@ -28,6 +29,10 @@ def configure_converter(converter: BaseConverter): * datetimes are serialized as timestamp floats * sets are serialized as lists + * string and int enums are passed through when unstructuring + + .. versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. 
""" converter.register_unstructure_hook(datetime, lambda v: v.timestamp()) converter.register_structure_hook( @@ -39,6 +44,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook( date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date() ) + converter.register_unstructure_hook_func(is_primitive_enum, identity) configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) diff --git a/src/cattrs/preconf/msgspec.py b/src/cattrs/preconf/msgspec.py index 6ef84d76..99e063e5 100644 --- a/src/cattrs/preconf/msgspec.py +++ b/src/cattrs/preconf/msgspec.py @@ -72,11 +72,15 @@ def configure_converter(converter: Converter) -> None: * datetimes and dates are passed through to be serialized as RFC 3339 directly * enums are passed through to msgspec directly * union passthrough configured for str, bool, int, float and None + * bare, string and int enums are passed through when unstructuring + + .. versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. 
""" configure_passthroughs(converter) converter.register_unstructure_hook(Struct, to_builtins) - converter.register_unstructure_hook(Enum, to_builtins) + converter.register_unstructure_hook(Enum, identity) converter.register_structure_hook(Struct, convert) converter.register_structure_hook(bytes, lambda v, _: b64decode(v)) @@ -100,7 +104,7 @@ def configure_passthroughs(converter: Converter) -> None: converter.register_unstructure_hook(bytes, to_builtins) converter.register_unstructure_hook_factory(is_mapping, mapping_unstructure_factory) converter.register_unstructure_hook_factory(is_sequence, seq_unstructure_factory) - converter.register_unstructure_hook_factory(has, attrs_unstructure_factory) + converter.register_unstructure_hook_factory(has, msgspec_attrs_unstructure_factory) converter.register_unstructure_hook_factory( is_namedtuple, namedtuple_unstructure_factory ) @@ -145,7 +149,9 @@ def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHo return converter.gen_unstructure_mapping(type) -def attrs_unstructure_factory(type: Any, converter: Converter) -> UnstructureHook: +def msgspec_attrs_unstructure_factory( + type: Any, converter: Converter +) -> UnstructureHook: """Choose whether to use msgspec handling or our own.""" origin = get_origin(type) attribs = fields(origin or type) diff --git a/src/cattrs/preconf/orjson.py b/src/cattrs/preconf/orjson.py index 1594ce6c..6518fbb6 100644 --- a/src/cattrs/preconf/orjson.py +++ b/src/cattrs/preconf/orjson.py @@ -13,7 +13,7 @@ from ..converters import BaseConverter, Converter from ..fns import identity from ..strategies import configure_union_passthrough -from . import wrap +from . 
import is_primitive_enum, wrap T = TypeVar("T") @@ -36,9 +36,12 @@ def configure_converter(converter: BaseConverter): * sets are serialized as lists * string enum mapping keys have special handling * mapping keys are coerced into strings when unstructuring + * bare, string and int enums are passed through when unstructuring .. versionchanged: 24.1.0 Add support for typed namedtuples. + .. versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. """ converter.register_unstructure_hook( bytes, lambda v: (b85encode(v) if v else b"").decode("utf8") @@ -80,6 +83,9 @@ def key_handler(v): ), ] ) + converter.register_unstructure_hook_func( + partial(is_primitive_enum, include_bare_enums=True), identity + ) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/ujson.py b/src/cattrs/preconf/ujson.py index c5906d21..f4168fca 100644 --- a/src/cattrs/preconf/ujson.py +++ b/src/cattrs/preconf/ujson.py @@ -6,11 +6,11 @@ from ujson import dumps, loads -from cattrs._compat import AbstractSet - +from .._compat import AbstractSet from ..converters import BaseConverter, Converter +from ..fns import identity from ..strategies import configure_union_passthrough -from . import wrap +from . import is_primitive_enum, wrap T = TypeVar("T") @@ -30,6 +30,10 @@ def configure_converter(converter: BaseConverter): * bytes are serialized as base64 strings * datetimes are serialized as ISO 8601 * sets are serialized as lists + * string and int enums are passed through when unstructuring + + .. versionchanged: 24.2.0 + Enums are left to the library to unstructure, speeding them up. 
""" converter.register_unstructure_hook( bytes, lambda v: (b85encode(v) if v else b"").decode("utf8") @@ -40,6 +44,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_func(is_primitive_enum, identity) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/tests/test_preconf.py b/tests/test_preconf.py index 9199c371..e4d35f98 100644 --- a/tests/test_preconf.py +++ b/tests/test_preconf.py @@ -43,6 +43,7 @@ Set, TupleSubscriptable, ) +from cattrs.fns import identity from cattrs.preconf.bson import make_converter as bson_make_converter from cattrs.preconf.cbor2 import make_converter as cbor2_make_converter from cattrs.preconf.json import make_converter as json_make_converter @@ -50,6 +51,9 @@ from cattrs.preconf.tomlkit import make_converter as tomlkit_make_converter from cattrs.preconf.ujson import make_converter as ujson_make_converter +NO_MSGSPEC: Final = python_implementation() == "PyPy" or sys.version_info[:2] >= (3, 13) +NO_ORJSON: Final = python_implementation() == "PyPy" + @define class A: @@ -75,6 +79,9 @@ class AnIntEnum(IntEnum): class AStringEnum(str, Enum): A = "a" + class ABareEnum(Enum): + B = "b" + string: str bytes: bytes an_int: int @@ -93,6 +100,7 @@ class AStringEnum(str, Enum): a_frozenset: FrozenSet[str] an_int_enum: AnIntEnum a_str_enum: AStringEnum + a_bare_enum: ABareEnum a_datetime: datetime a_date: date a_string_enum_dict: Dict[AStringEnum, int] @@ -162,6 +170,7 @@ def everythings( draw(frozensets(strings)), Everything.AnIntEnum.A, Everything.AStringEnum.A, + Everything.ABareEnum.B, draw(dts), draw(dates(min_value=date(1970, 1, 1), max_value=date(2038, 1, 1))), draw(dictionaries(just(Everything.AStringEnum.A), ints)), @@ -325,6 +334,31 @@ def 
test_stdlib_json_unions_with_spillover( assert converter.structure(converter.unstructure(val), type) == val +def test_stdlib_json_native_enums(): + """Bare, string and int enums are handled correctly.""" + converter = json_make_converter() + assert ( + json_loads(converter.dumps(Everything.AnIntEnum.A)) + == Everything.AnIntEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.AStringEnum.A)) + == Everything.AStringEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.ABareEnum.B)) + == Everything.ABareEnum.B.value + ) + + +def test_stdlib_json_efficient_enum(): + """`str` and `int` enums are handled efficiently.""" + converter = json_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + + @given( everythings( min_int=-9223372036854775808, max_int=9223372036854775807, allow_inf=False @@ -377,7 +411,32 @@ def test_ujson_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val -@pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") +def test_ujson_native_enums(): + """Bare, string and int enums are handled correctly.""" + converter = ujson_make_converter() + assert ( + json_loads(converter.dumps(Everything.AnIntEnum.A)) + == Everything.AnIntEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.AStringEnum.A)) + == Everything.AStringEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.ABareEnum.B)) + == Everything.ABareEnum.B.value + ) + + +def test_ujson_efficient_enum(): + """Bare, `str` and `int` enums are handled efficiently.""" + converter = ujson_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + + +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") @given( everythings( 
min_int=-9223372036854775808, max_int=9223372036854775807, allow_inf=False @@ -395,7 +454,7 @@ def test_orjson(everything: Everything, detailed_validation: bool): assert converter.structure(orjson_loads(raw), Everything) == everything -@pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") @given( everythings( min_int=-9223372036854775808, max_int=9223372036854775807, allow_inf=False @@ -410,7 +469,7 @@ def test_orjson_converter(everything: Everything, detailed_validation: bool): assert converter.loads(raw, Everything) == everything -@pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") @given( everythings( min_int=-9223372036854775808, max_int=9223372036854775807, allow_inf=False @@ -428,7 +487,7 @@ def test_orjson_converter_unstruct_collection_overrides(everything: Everything): assert raw["a_frozenset"] == sorted(raw["a_frozenset"]) -@pytest.mark.skipif(python_implementation() == "PyPy", reason="no orjson on PyPy") +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") @given( union_and_val=native_unions(include_bytes=False, include_datetimes=False), detailed_validation=..., @@ -443,6 +502,39 @@ def test_orjson_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") +def test_orjson_native_enums(): + """Bare, string and int enums are handled correctly.""" + from cattrs.preconf.orjson import make_converter as orjson_make_converter + + converter = orjson_make_converter() + + assert ( + json_loads(converter.dumps(Everything.AnIntEnum.A)) + == Everything.AnIntEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.AStringEnum.A)) + == Everything.AStringEnum.A.value + ) + assert ( + json_loads(converter.dumps(Everything.ABareEnum.B)) + == 
Everything.ABareEnum.B.value + ) + + +@pytest.mark.skipif(NO_ORJSON, reason="orjson not available") +def test_orjson_efficient_enum(): + """Bare, `str` and `int` enums are handled efficiently.""" + from cattrs.preconf.orjson import make_converter as orjson_make_converter + + converter = orjson_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(Everything.ABareEnum) == identity + + @given(everythings(min_int=-9223372036854775808, max_int=18446744073709551615)) def test_msgpack(everything: Everything): from msgpack import dumps as msgpack_dumps @@ -483,6 +575,30 @@ def test_msgpack_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val +def test_msgpack_native_enums(): + """Bare, string and int enums are handled correctly.""" + + converter = msgpack_make_converter() + + assert converter.dumps(Everything.AnIntEnum.A) == converter.dumps( + Everything.AnIntEnum.A.value + ) + assert converter.dumps(Everything.AStringEnum.A) == converter.dumps( + Everything.AStringEnum.A.value + ) + assert converter.dumps(Everything.ABareEnum.B) == converter.dumps( + Everything.ABareEnum.B.value + ) + + +def test_msgpack_efficient_enum(): + """`str` and `int` enums are handled efficiently.""" + converter = msgpack_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + + @given( everythings( min_int=-9223372036854775808, @@ -551,6 +667,38 @@ def test_bson_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val +def test_bson_objectid(): + """BSON ObjectIds are supported by default.""" + converter = bson_make_converter() + o = ObjectId() + assert o == converter.structure(str(o), ObjectId) + assert o == converter.structure(o, 
ObjectId) + + +def test_bson_native_enums(): + """Bare, string and int enums are handled correctly.""" + + converter = bson_make_converter() + + assert converter.dumps({"a": Everything.AnIntEnum.A}) == converter.dumps( + {"a": Everything.AnIntEnum.A.value} + ) + assert converter.dumps({"a": Everything.AStringEnum.A}) == converter.dumps( + {"a": Everything.AStringEnum.A.value} + ) + assert converter.dumps({"a": Everything.ABareEnum.B}) == converter.dumps( + {"a": Everything.ABareEnum.B.value} + ) + + +def test_bson_efficient_enum(): + """`str` and `int` enums are handled efficiently.""" + converter = bson_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + + @given( everythings( min_key_length=1, @@ -617,14 +765,6 @@ def test_tomlkit_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val -def test_bson_objectid(): - """BSON ObjectIds are supported by default.""" - converter = bson_make_converter() - o = ObjectId() - assert o == converter.structure(str(o), ObjectId) - assert o == converter.structure(o, ObjectId) - - @given(everythings(min_int=-9223372036854775808, max_int=18446744073709551615)) def test_cbor2(everything: Everything): from cbor2 import dumps as cbor2_dumps @@ -662,7 +802,28 @@ def test_cbor2_unions(union_and_val: tuple, detailed_validation: bool): assert converter.structure(val, type) == val -NO_MSGSPEC: Final = python_implementation() == "PyPy" or sys.version_info[:2] >= (3, 13) +def test_cbor2_native_enums(): + """Bare, string and int enums are handled correctly.""" + + converter = cbor2_make_converter() + + assert converter.dumps(Everything.AnIntEnum.A) == converter.dumps( + Everything.AnIntEnum.A.value + ) + assert converter.dumps(Everything.AStringEnum.A) == converter.dumps( + Everything.AStringEnum.A.value + ) + assert converter.dumps(Everything.ABareEnum.B) == converter.dumps( + 
Everything.ABareEnum.B.value + ) + + +def test_cbor2_efficient_enum(): + """`str` and `int` enums are handled efficiently.""" + converter = cbor2_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity @pytest.mark.skipif(NO_MSGSPEC, reason="msgspec not available") @@ -703,3 +864,33 @@ def test_msgspec_json_unions(union_and_val: tuple, detailed_validation: bool): type, val = union_and_val assert converter.structure(val, type) == val + + +@pytest.mark.skipif(NO_MSGSPEC, reason="msgspec not available") +def test_msgspec_native_enums(): + """Bare, string and int enums are handled correctly.""" + from cattrs.preconf.msgspec import make_converter as msgspec_make_converter + + converter = msgspec_make_converter() + + assert converter.dumps(Everything.AnIntEnum.A) == converter.dumps( + Everything.AnIntEnum.A.value + ) + assert converter.dumps(Everything.AStringEnum.A) == converter.dumps( + Everything.AStringEnum.A.value + ) + assert converter.dumps(Everything.ABareEnum.B) == converter.dumps( + Everything.ABareEnum.B.value + ) + + +@pytest.mark.skipif(NO_MSGSPEC, reason="msgspec not available") +def test_msgspec_efficient_enum(): + """Bare, `str` and `int` enums are handled efficiently.""" + from cattrs.preconf.msgspec import make_converter as msgspec_make_converter + + converter = msgspec_make_converter() + + assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity + assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(Everything.ABareEnum) == identity From bf3cda107f8ef50a0a8d0badde10679e59ca1b5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Mon, 11 Nov 2024 23:27:18 +0100 Subject: [PATCH 3/3] Optimize literals with enums --- HISTORY.md | 2 +- src/cattrs/preconf/__init__.py | 19 ++++++++++++++++++- src/cattrs/preconf/bson.py | 11 ++++++++++- 
src/cattrs/preconf/cbor2.py | 6 +++++- src/cattrs/preconf/json.py | 6 +++++- src/cattrs/preconf/msgpack.py | 6 +++++- src/cattrs/preconf/msgspec.py | 6 +++++- src/cattrs/preconf/orjson.py | 6 +++++- src/cattrs/preconf/ujson.py | 6 +++++- tests/test_literals.py | 4 ++-- tests/test_preconf.py | 23 +++++++++++++++++++++-- 11 files changed, 82 insertions(+), 13 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 669f53d3..02dfc8c0 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -25,7 +25,7 @@ Our backwards-compatibility policy can be found [here](https://github.com/python - Many preconf converters (_bson_, stdlib JSON, _cbor2_, _msgpack_, _msgspec_, _orjson_, _ujson_) skip unstructuring `int` and `str` enums, leaving them to the underlying libraries to handle with greater efficiency. ([#598](https://github.com/python-attrs/cattrs/pull/598)) -- Literals containing enums are now unstructured properly. +- Literals containing enums are now unstructured properly, and their unstructuring is greatly optimized in the _bson_, stdlib JSON, _cbor2_, _msgpack_, _msgspec_, _orjson_ and _ujson_ preconf converters. ([#598](https://github.com/python-attrs/cattrs/pull/598)) - Replace `cattrs.gen.MappingStructureFn` with `cattrs.SimpleStructureHook[In, T]`. - Python 3.13 is now supported. 
diff --git a/src/cattrs/preconf/__init__.py b/src/cattrs/preconf/__init__.py index ec1dfc31..1b12ef93 100644 --- a/src/cattrs/preconf/__init__.py +++ b/src/cattrs/preconf/__init__.py @@ -1,9 +1,11 @@ import sys from datetime import datetime from enum import Enum -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, get_args from .._compat import is_subclass +from ..converters import Converter, UnstructureHook +from ..fns import identity if sys.version_info[:2] < (3, 10): from typing_extensions import ParamSpec @@ -36,3 +38,18 @@ def is_primitive_enum(type: Any, include_bare_enums: bool = False) -> bool: is_subclass(type, (str, int)) or (include_bare_enums and type.mro()[1:] == Enum.mro()) ) + + +def literals_with_enums_unstructure_factory( + typ: Any, converter: Converter +) -> UnstructureHook: + """An unstructure hook factory for literals containing enums. + + If all contained enums can be passed through (their unstructure hook is `identity`), + the entire literal can also be passed through. + """ + if all( + converter.get_unstructure_hook(type(arg)) == identity for arg in get_args(typ) + ): + return identity + return converter.unstructure diff --git a/src/cattrs/preconf/bson.py b/src/cattrs/preconf/bson.py index 55307cab..ed6e361d 100644 --- a/src/cattrs/preconf/bson.py +++ b/src/cattrs/preconf/bson.py @@ -12,8 +12,14 @@ from ..converters import BaseConverter, Converter from ..dispatch import StructureHook from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, validate_datetime, wrap +from . 
import ( + is_primitive_enum, + literals_with_enums_unstructure_factory, + validate_datetime, + wrap, +) T = TypeVar("T") @@ -98,6 +104,9 @@ def gen_structure_mapping(cl: Any) -> StructureHook: converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) @wrap(BsonConverter) diff --git a/src/cattrs/preconf/cbor2.py b/src/cattrs/preconf/cbor2.py index a963cb70..13e224ef 100644 --- a/src/cattrs/preconf/cbor2.py +++ b/src/cattrs/preconf/cbor2.py @@ -9,8 +9,9 @@ from ..converters import BaseConverter, Converter from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, wrap +from . import is_primitive_enum, literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -38,6 +39,9 @@ def configure_converter(converter: BaseConverter): converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) diff --git a/src/cattrs/preconf/json.py b/src/cattrs/preconf/json.py index cdfa5942..2865326f 100644 --- a/src/cattrs/preconf/json.py +++ b/src/cattrs/preconf/json.py @@ -8,8 +8,9 @@ from .._compat import AbstractSet, Counter from ..converters import BaseConverter, Converter from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, wrap +from . 
import is_primitive_enum, literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -45,6 +46,9 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) converter.register_unstructure_hook_func(is_primitive_enum, identity) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/msgpack.py b/src/cattrs/preconf/msgpack.py index 2a27cace..9549dfcb 100644 --- a/src/cattrs/preconf/msgpack.py +++ b/src/cattrs/preconf/msgpack.py @@ -9,8 +9,9 @@ from ..converters import BaseConverter, Converter from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, wrap +from . 
import is_primitive_enum, literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -45,6 +46,9 @@ def configure_converter(converter: BaseConverter): date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date() ) converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter) diff --git a/src/cattrs/preconf/msgspec.py b/src/cattrs/preconf/msgspec.py index 99e063e5..62673c27 100644 --- a/src/cattrs/preconf/msgspec.py +++ b/src/cattrs/preconf/msgspec.py @@ -27,8 +27,9 @@ from ..dispatch import UnstructureHook from ..fns import identity from ..gen import make_hetero_tuple_unstructure_fn +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import wrap +from . import literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -86,6 +87,9 @@ def configure_converter(converter: Converter) -> None: converter.register_structure_hook(bytes, lambda v, _: b64decode(v)) converter.register_structure_hook(datetime, lambda v, _: convert(v, datetime)) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/orjson.py b/src/cattrs/preconf/orjson.py index 6518fbb6..6609febd 100644 --- a/src/cattrs/preconf/orjson.py +++ b/src/cattrs/preconf/orjson.py @@ -12,8 +12,9 @@ from ..cols import is_namedtuple, namedtuple_unstructure_factory from ..converters import BaseConverter, Converter from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, wrap +from . 
import is_primitive_enum, literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -86,6 +87,9 @@ def key_handler(v): converter.register_unstructure_hook_func( partial(is_primitive_enum, include_bare_enums=True), identity ) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/ujson.py b/src/cattrs/preconf/ujson.py index f4168fca..0c7fec4e 100644 --- a/src/cattrs/preconf/ujson.py +++ b/src/cattrs/preconf/ujson.py @@ -9,8 +9,9 @@ from .._compat import AbstractSet from ..converters import BaseConverter, Converter from ..fns import identity +from ..literals import is_literal_containing_enums from ..strategies import configure_union_passthrough -from . import is_primitive_enum, wrap +from . import is_primitive_enum, literals_with_enums_unstructure_factory, wrap T = TypeVar("T") @@ -45,6 +46,9 @@ def configure_converter(converter: BaseConverter): converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory( + is_literal_containing_enums, literals_with_enums_unstructure_factory + ) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/tests/test_literals.py b/tests/test_literals.py index 37317177..e9dbc9d8 100644 --- a/tests/test_literals.py +++ b/tests/test_literals.py @@ -5,7 +5,7 @@ from cattrs.fns import identity -class TestEnum(Enum): +class AnEnum(Enum): TEST = "test" @@ -16,4 +16,4 @@ def test_unstructure_literal(converter: BaseConverter): def test_unstructure_literal_with_enum(converter: BaseConverter): """Literals with enums are properly unstructured.""" - assert converter.unstructure(TestEnum.TEST, Literal[TestEnum.TEST]) == "test" + assert 
converter.unstructure(AnEnum.TEST, Literal[AnEnum.TEST]) == "test" diff --git a/tests/test_preconf.py b/tests/test_preconf.py index e4d35f98..2ab0b107 100644 --- a/tests/test_preconf.py +++ b/tests/test_preconf.py @@ -4,10 +4,10 @@ from json import dumps as json_dumps from json import loads as json_loads from platform import python_implementation -from typing import Any, Dict, Final, List, NamedTuple, NewType, Union +from typing import Any, Dict, Final, List, Literal, NamedTuple, NewType, Union import pytest -from attrs import define +from attrs import define, fields from bson import CodecOptions, ObjectId from hypothesis import given, settings from hypothesis.strategies import ( @@ -109,6 +109,8 @@ class ABareEnum(Enum): native_union_with_spillover: Union[int, str, Set[str]] native_union_with_union_spillover: Union[int, str, A, B] a_namedtuple: C + a_literal: Literal[1, AStringEnum.A] + a_literal_with_bare: Literal[1, ABareEnum.B] @composite @@ -179,6 +181,8 @@ def everythings( draw(one_of(ints, strings, sets(strings))), draw(one_of(ints, strings, ints.map(A), strings.map(B))), draw(fs.map(C)), + draw(one_of(just(1), just(Everything.AStringEnum.A))), + draw(one_of(just(1), just(Everything.ABareEnum.B))), ) @@ -357,6 +361,7 @@ def test_stdlib_json_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity @given( @@ -434,6 +439,7 @@ def test_ujson_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity @pytest.mark.skipif(NO_ORJSON, reason="orjson not available") @@ -533,6 +539,11 @@ def test_orjson_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity
assert converter.get_unstructure_hook(Everything.AStringEnum) == identity assert converter.get_unstructure_hook(Everything.ABareEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity + assert ( + converter.get_unstructure_hook(fields(Everything).a_literal_with_bare.type) + == identity + ) @given(everythings(min_int=-9223372036854775808, max_int=18446744073709551615)) @@ -597,6 +608,7 @@ def test_msgpack_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity @given( @@ -697,6 +709,7 @@ def test_bson_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity @given( @@ -824,6 +837,7 @@ def test_cbor2_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity @pytest.mark.skipif(NO_MSGSPEC, reason="msgspec not available") @@ -894,3 +908,8 @@ def test_msgspec_efficient_enum(): assert converter.get_unstructure_hook(Everything.AnIntEnum) == identity assert converter.get_unstructure_hook(Everything.AStringEnum) == identity assert converter.get_unstructure_hook(Everything.ABareEnum) == identity + assert converter.get_unstructure_hook(fields(Everything).a_literal.type) == identity + assert ( + converter.get_unstructure_hook(fields(Everything).a_literal_with_bare.type) + == identity + )