diff --git a/HISTORY.md b/HISTORY.md index 86e8db55..5f811e26 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -50,6 +50,8 @@ - _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/). - Remove some unused lines in the unstructuring code. ([#416](https://github.com/python-attrs/cattrs/pull/416)) +- Fix handling classes inheriting from non-generic protocols. + ([#374](https://github.com/python-attrs/cattrs/issues/374)) ## 23.1.2 (2023-06-02) diff --git a/Makefile b/Makefile index f89ec9fa..c1c930dd 100644 --- a/Makefile +++ b/Makefile @@ -49,6 +49,7 @@ clean-test: ## remove test and coverage artifacts lint: ## check style with ruff and black pdm run ruff src/ tests + pdm run isort -c src/ tests pdm run black --check src tests docs/conf.py test: ## run tests quickly with the default Python diff --git a/src/cattrs/_compat.py b/src/cattrs/_compat.py index d4a6a330..8cd437f3 100644 --- a/src/cattrs/_compat.py +++ b/src/cattrs/_compat.py @@ -316,8 +316,11 @@ def is_counter(type): ) def is_generic(obj) -> bool: - return isinstance(obj, (_GenericAlias, GenericAlias)) or is_subclass( - obj, Generic + """Whether obj is a generic type.""" + # Inheriting from protocol will inject `Generic` into the MRO + # without `__orig_bases__`. + return isinstance(obj, (_GenericAlias, GenericAlias)) or ( + is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__") ) def copy_with(type, args): @@ -343,7 +346,7 @@ def get_full_type_hints(obj, globalns=None, localns=None): TupleSubscriptable = Tuple from collections import Counter as ColCounter - from typing import Counter, TypedDict, Union, _GenericAlias + from typing import Counter, Generic, TypedDict, Union, _GenericAlias from typing_extensions import Annotated, NotRequired, Required from typing_extensions import get_origin as te_get_origin @@ -429,7 +432,9 @@ def is_literal(type) -> bool: return type.__class__ is _GenericAlias and type.__origin__ is Literal def is_generic(obj): - return isinstance(obj, _GenericAlias) + return isinstance(obj, _GenericAlias) or ( + is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__") + ) def copy_with(type, args): """Replace a generic type's arguments.""" diff --git a/src/cattrs/preconf/orjson.py b/src/cattrs/preconf/orjson.py index 0b4f32de..5cc74729 100644 --- a/src/cattrs/preconf/orjson.py +++ b/src/cattrs/preconf/orjson.py @@ -1,6 +1,6 @@ """Preconfigured converters for orjson.""" from base64 import b85decode, b85encode -from datetime import datetime, date +from datetime import date, datetime from enum import Enum from typing import Any, Type, TypeVar, Union diff --git a/tests/test_baseconverter.py b/tests/test_baseconverter.py index e4ac50fe..63057015 100644 --- a/tests/test_baseconverter.py +++ b/tests/test_baseconverter.py @@ -1,9 +1,8 @@ """Test both structuring and unstructuring.""" from typing import Optional, Union -import attr import pytest -from attr import define, fields, make_class +from attrs import define, fields, make_class from hypothesis import HealthCheck, assume, given, settings from hypothesis.strategies import just, one_of @@ -90,9 +89,9 @@ def test_union_field_roundtrip(cl_and_vals_a, cl_and_vals_b, strat): common_names = a_field_names & b_field_names assume(len(a_field_names) > len(common_names)) - @attr.s + @define class C: - a = attr.ib(type=Union[cl_a, cl_b]) + a: Union[cl_a, cl_b] inst = C(a=cl_a(*vals_a, **kwargs_a)) @@ -161,9 +160,9 @@ def test_optional_field_roundtrip(cl_and_vals): converter = BaseConverter() cl, vals, kwargs = cl_and_vals - @attr.s + @define class C: - a = attr.ib(type=Optional[cl]) + a: Optional[cl] inst = C(a=cl(*vals, **kwargs)) assert inst == converter.structure(converter.unstructure(inst), C) diff --git a/tests/test_gen.py b/tests/test_gen.py index 05fce839..cf3974b1 100644 --- a/tests/test_gen.py +++ b/tests/test_gen.py @@ -2,7 +2,7 @@ import linecache from traceback import format_exc -from attr import define +from attrs import define from cattrs import Converter from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn diff --git a/tests/test_generics.py b/tests/test_generics.py index 97e3233b..9d075187 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -2,7 +2,7 @@ from typing import Deque, Dict, Generic, List, Optional, TypeVar, Union import pytest -from attr import asdict, attrs, define +from attrs import asdict, define from cattrs import BaseConverter, Converter from cattrs._compat import Protocol @@ -132,7 +132,7 @@ def test_able_to_structure_deeply_nested_generics_gen(converter): def test_structure_unions_of_generics(converter): - @attrs(auto_attribs=True) + @define class TClass2(Generic[T]): c: T @@ -142,7 +142,7 @@ class TClass2(Generic[T]): def test_structure_list_of_generic_unions(converter): - @attrs(auto_attribs=True) + @define class TClass2(Generic[T]): c: T @@ -154,7 +154,7 @@ class TClass2(Generic[T]): def test_structure_deque_of_generic_unions(converter): - @attrs(auto_attribs=True) + @define class TClass2(Generic[T]): c: T @@ -179,35 +179,31 @@ def test_raises_if_no_generic_params_supplied( assert exc.value.type_ is T -def test_unstructure_generic_attrs(): - c = Converter() - - @attrs(auto_attribs=True) +def test_unstructure_generic_attrs(genconverter): + @define class Inner(Generic[T]): a: T - @attrs(auto_attribs=True) + @define class Outer: inner: Inner[int] initial = Outer(Inner(1)) - raw = c.unstructure(initial) + raw = genconverter.unstructure(initial) assert raw == {"inner": {"a": 1}} - new = c.structure(raw, Outer) + new = genconverter.structure(raw, Outer) assert initial == new - @attrs(auto_attribs=True) + @define class OuterStr: inner: Inner[str] - assert c.structure(raw, OuterStr) == OuterStr(Inner("1")) - + assert genconverter.structure(raw, OuterStr) == OuterStr(Inner("1")) -def test_unstructure_deeply_nested_generics(): - c = Converter() +def test_unstructure_deeply_nested_generics(genconverter): @define class Inner: a: int @@ -217,16 +213,14 @@ class Outer(Generic[T]): inner: T initial = Outer[Inner](Inner(1)) - raw = c.unstructure(initial, Outer[Inner]) + raw = genconverter.unstructure(initial, Outer[Inner]) assert raw == {"inner": {"a": 1}} - raw = c.unstructure(initial) + raw = genconverter.unstructure(initial) assert raw == {"inner": {"a": 1}} -def test_unstructure_deeply_nested_generics_list(): - c = Converter() - +def test_unstructure_deeply_nested_generics_list(genconverter): @define class Inner: a: int @@ -236,16 +230,14 @@ class Outer(Generic[T]): inner: List[T] initial = Outer[Inner]([Inner(1)]) - raw = c.unstructure(initial, Outer[Inner]) + raw = genconverter.unstructure(initial, Outer[Inner]) assert raw == {"inner": [{"a": 1}]} - raw = c.unstructure(initial) + raw = genconverter.unstructure(initial) assert raw == {"inner": [{"a": 1}]} -def test_unstructure_protocol(): - c = Converter() - +def test_unstructure_protocol(genconverter): class Proto(Protocol): a: int @@ -258,10 +250,10 @@ class Outer: inner: Proto initial = Outer(Inner(1)) - raw = c.unstructure(initial, Outer) + raw = genconverter.unstructure(initial, Outer) assert raw == {"inner": {"a": 1}} - raw = c.unstructure(initial) + raw = genconverter.unstructure(initial) assert raw == {"inner": {"a": 1}} @@ -306,3 +298,27 @@ class B(A[int]): pass assert generate_mapping(B, {}) == {T.__name__: int} + + +def test_nongeneric_protocols(converter): + """Non-generic protocols work.""" + + class NongenericProtocol(Protocol): + ... + + @define + class Entity(NongenericProtocol): + ... + + assert generate_mapping(Entity) == {} + + class GenericProtocol(Protocol[T]): + ... + + @define + class GenericEntity(GenericProtocol[int]): + a: int + + assert generate_mapping(GenericEntity) == {"T": int} + + assert converter.structure({"a": 1}, GenericEntity) == GenericEntity(1) diff --git a/tests/test_validation.py b/tests/test_validation.py index d472bb15..575bbf2f 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,8 +1,8 @@ """Tests for the extended validation mode.""" +import pickle from typing import Dict, FrozenSet, List, Set, Tuple import pytest -import pickle from attrs import define, field from attrs.validators import in_ from hypothesis import given